
API Reference

tnh_scholar

TNH Scholar: Text Processing and Analysis Tools

TNH Scholar is an AI-driven project designed to explore, query, process, and translate the teachings of Thich Nhat Hanh and other Plum Village Dharma teachers. The project aims to give practitioners and scholars a resource for engaging deeply with mindfulness and spiritual wisdom through natural language processing and machine learning models.

Core Features
  • Audio transcription and processing
  • Multi-lingual text processing and translation
  • Pattern-based text analysis
  • OCR processing for historical documents
  • CLI tools for batch processing
Package Structure
  • tnh_scholar/
  • CLI_tools/ - Command line interface tools
  • audio_processing/ - Audio file handling and transcription
  • journal_processing/ - Journal and publication processing
  • ocr_processing/ - Optical character recognition tools
  • openai_interface/ - OpenAI API integration
  • text_processing/ - Core text processing utilities
  • video_processing/ - Video file handling and transcription
  • utils/ - Shared utility functions
  • xml_processing/ - XML parsing and generation
Environment Configuration
The package uses environment variables for configuration, including the following (a minimal configuration sketch follows this list):
  • TNH_PATTERN_DIR - Directory for text processing patterns
  • OPENAI_API_KEY - OpenAI API authentication
  • GOOGLE_VISION_KEY - Google Cloud Vision API key for OCR
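
A minimal sketch of reading this configuration in Python. The .env loading step mirrors what the package itself does via python-dotenv; the fallback path reflects TNH_DEFAULT_PATTERN_DIR (documented below) and is otherwise an assumption:

import os
from pathlib import Path

from dotenv import load_dotenv  # python-dotenv, also used by the package itself

load_dotenv()  # pick up a local .env file, if present

# TNH_PATTERN_DIR falls back to ~/.config/tnh-scholar/patterns (TNH_DEFAULT_PATTERN_DIR)
pattern_dir = Path(
    os.getenv("TNH_PATTERN_DIR", Path.home() / ".config" / "tnh-scholar" / "patterns")
)
openai_key = os.environ["OPENAI_API_KEY"]    # required for OpenAI-backed processing
vision_key = os.getenv("GOOGLE_VISION_KEY")  # only needed for Google Cloud Vision OCR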
CLI Tools
  • audio-transcribe - Audio file transcription utility
  • tnh-fab - Text processing and analysis toolkit

For more information, see:
  • Documentation: https://aaronksolomon.github.io/tnh-scholar/
  • Source: https://github.com/aaronksolomon/tnh-scholar
  • Issues: https://github.com/aaronksolomon/tnh-scholar/issues

Dependencies
  • Core: click, pydantic, openai, yt-dlp
  • Optional: streamlit (GUI), spacy (NLP), google-cloud-vision (OCR)

TNH_CLI_TOOLS_DIR = TNH_ROOT_SRC_DIR / 'cli_tools' module-attribute

TNH_CONFIG_DIR = Path.home() / '.config' / 'tnh-scholar' module-attribute

TNH_DEFAULT_PATTERN_DIR = TNH_CONFIG_DIR / 'patterns' module-attribute

TNH_LOG_DIR = TNH_CONFIG_DIR / 'logs' module-attribute

TNH_PROJECT_ROOT_DIR = TNH_ROOT_SRC_DIR.resolve().parent.parent module-attribute

TNH_ROOT_SRC_DIR = Path(__file__).resolve().parent module-attribute

__version__ = '0.1.3' module-attribute
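
These module-level paths and the version string can be imported directly. A minimal sketch, assuming the attributes are importable from the top-level tnh_scholar package as listed above:

from tnh_scholar import (
    TNH_CONFIG_DIR,
    TNH_DEFAULT_PATTERN_DIR,
    TNH_LOG_DIR,
    __version__,
)

print(__version__)              # '0.1.3'
print(TNH_CONFIG_DIR)           # ~/.config/tnh-scholar
print(TNH_DEFAULT_PATTERN_DIR)  # ~/.config/tnh-scholar/patterns

# e.g. make sure the log directory exists before configuring logging
TNH_LOG_DIR.mkdir(parents=True, exist_ok=True)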

ai_text_processing

ai_text_processing

DEFAULT_MIN_SECTION_COUNT = 3 module-attribute
DEFAULT_OPENAI_MODEL = 'gpt-4o' module-attribute
DEFAULT_PARAGRAPH_FORMAT_PATTERN = 'default_xml_paragraph_format' module-attribute
DEFAULT_PUNCTUATE_MODEL = 'gpt-4o' module-attribute
DEFAULT_PUNCTUATE_PATTERN = 'default_punctuate' module-attribute
DEFAULT_PUNCTUATE_STYLE = 'APA' module-attribute
DEFAULT_REVIEW_COUNT = 5 module-attribute
DEFAULT_SECTION_PATTERN = 'default_section' module-attribute
DEFAULT_SECTION_RANGE_VAR = 2 module-attribute
DEFAULT_SECTION_RESULT_MAX_SIZE = 4000 module-attribute
DEFAULT_SECTION_TOKEN_SIZE = 650 module-attribute
DEFAULT_TARGET_LANGUAGE = 'English' module-attribute
DEFAULT_TRANSLATE_CONTEXT_LINES = 3 module-attribute
DEFAULT_TRANSLATE_SEGMENT_SIZE = 20 module-attribute
DEFAULT_TRANSLATE_STYLE = "'American Dharma Teaching'" module-attribute
DEFAULT_TRANSLATION_PATTERN = 'default_line_translation' module-attribute
DEFAULT_TRANSLATION_TARGET_TOKENS = 650 module-attribute
DEFAULT_XML_FORMAT_PATTERN = 'default_xml_format' module-attribute
FOLLOWING_CONTEXT_MARKER = 'FOLLOWING_CONTEXT' module-attribute
PRECEDING_CONTEXT_MARKER = 'PRECEDING_CONTEXT' module-attribute
SECTION_SEGMENT_SIZE_WARNING_LIMIT = 5 module-attribute
TRANSCRIPT_SEGMENT_MARKER = 'TRANSCRIPT_SEGMENT' module-attribute
logger = get_child_logger(__name__) module-attribute
GeneralProcessor
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class GeneralProcessor:
    def __init__(
        self,
        processor: TextProcessor,
        pattern: Pattern,
        source_language: Optional[str] = None,
        review_count: int = DEFAULT_REVIEW_COUNT,
    ):
        """
        Initialize punctuation generator.

        Args:
            text_punctuator: Implementation of TextProcessor
            punctuate_pattern: Pattern object containing punctuation instructions
            section_count: Target number of sections
            review_count: Number of review passes
        """

        self.source_language = source_language
        self.processor = processor
        self.pattern = pattern
        self.review_count = review_count

    def process_text(
        self,
        text: str,
        source_language: Optional[str] = None,
        template_dict: Optional[Dict] = None,
    ) -> str:
        """
        Process a text based on a pattern and source language.
        """

        if not source_language:
            if self.source_language:
                source_language = self.source_language
            else:
                source_language = get_language_name(text)

        template_values = {
            "source_language": source_language,
            "review_count": self.review_count,
        }

        if template_dict:
            template_values |= template_dict

        logger.info("Processing text...")
        instructions = self.pattern.apply_template(template_values)

        logger.debug(f"Process instructions:\n{instructions}")

        text = self.processor.process_text(text, instructions)
        logger.info("Processing completed.")

        # normalize newline spacing to two newlines between lines and return
        # commented out to allow pattern to dictate newlines:
        # return normalize_newlines(text)
        return text
pattern = pattern instance-attribute
processor = processor instance-attribute
review_count = review_count instance-attribute
source_language = source_language instance-attribute
__init__(processor, pattern, source_language=None, review_count=DEFAULT_REVIEW_COUNT)

Initialize a general text processor.

Parameters:

  • processor (TextProcessor): Implementation of TextProcessor. Required.
  • pattern (Pattern): Pattern object containing the processing instructions. Required.
  • source_language (Optional[str]): Source language of the text; auto-detected when omitted. Default: None.
  • review_count (int): Number of review passes. Default: DEFAULT_REVIEW_COUNT.
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    processor: TextProcessor,
    pattern: Pattern,
    source_language: Optional[str] = None,
    review_count: int = DEFAULT_REVIEW_COUNT,
):
    """
    Initialize punctuation generator.

    Args:
        text_punctuator: Implementation of TextProcessor
        punctuate_pattern: Pattern object containing punctuation instructions
        section_count: Target number of sections
        review_count: Number of review passes
    """

    self.source_language = source_language
    self.processor = processor
    self.pattern = pattern
    self.review_count = review_count
process_text(text, source_language=None, template_dict=None)

Process a text based on a pattern and source language.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text(
    self,
    text: str,
    source_language: Optional[str] = None,
    template_dict: Optional[Dict] = None,
) -> str:
    """
    Process a text based on a pattern and source language.
    """

    if not source_language:
        if self.source_language:
            source_language = self.source_language
        else:
            source_language = get_language_name(text)

    template_values = {
        "source_language": source_language,
        "review_count": self.review_count,
    }

    if template_dict:
        template_values |= template_dict

    logger.info("Processing text...")
    instructions = self.pattern.apply_template(template_values)

    logger.debug(f"Process instructions:\n{instructions}")

    text = self.processor.process_text(text, instructions)
    logger.info("Processing completed.")

    # normalize newline spacing to two newlines between lines and return
    # commented out to allow pattern to dictate newlines:
    # return normalize_newlines(text)
    return text
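
A usage sketch for GeneralProcessor. The pattern name comes from DEFAULT_XML_FORMAT_PATTERN above; the get_pattern call is a hypothetical stand-in for however a Pattern is retrieved from PatternManager in your setup, and everything else uses only the classes documented on this page:

from tnh_scholar.ai_text_processing.ai_text_processing import (
    GeneralProcessor,
    LocalPatternManager,
    OpenAIProcessor,
)

# Hypothetical pattern lookup: substitute the actual PatternManager retrieval call.
pattern = LocalPatternManager().pattern_manager.get_pattern("default_xml_format")

general = GeneralProcessor(
    processor=OpenAIProcessor(),   # defaults to DEFAULT_OPENAI_MODEL ('gpt-4o')
    pattern=pattern,
    source_language="Vietnamese",  # skip auto-detection when the source is known
)

with open("talk_transcript.txt", encoding="utf-8") as f:
    formatted = general.process_text(f.read())

When source_language is omitted, process_text detects it with get_language_name before applying the pattern template.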
LineTranslator

Translates text line by line while maintaining line numbers and context.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class LineTranslator:
    """Translates text line by line while maintaining line numbers and context."""

    def __init__(
        self,
        processor: TextProcessor,
        pattern: Pattern,
        review_count: int = DEFAULT_REVIEW_COUNT,
        style: str = DEFAULT_TRANSLATE_STYLE,
        context_lines: int = DEFAULT_TRANSLATE_CONTEXT_LINES,  # Number of context lines before/after
    ):
        """
        Initialize line translator.

        Args:
            processor: Implementation of TextProcessor
            pattern: Pattern object containing translation instructions
            review_count: Number of review passes
            style: Translation style to apply
            context_lines: Number of context lines to include before/after
        """
        self.processor = processor
        self.pattern = pattern
        self.review_count = review_count
        self.style = style
        self.context_lines = context_lines

    def translate_segment(
        self,
        num_text: NumberedText,
        start_line: int,
        end_line: int,
        source_language: Optional[str] = None,
        target_language: Optional[str] = DEFAULT_TARGET_LANGUAGE,
        template_dict: Optional[Dict] = None,
    ) -> str:
        """
        Translate a segment of text with context.

        Args:
            num_text: NumberedText containing the full text to extract the segment from
            start_line: Starting line number of segment
            end_line: Ending line number of segment
            source_language: Source language code
            target_language: Target language code (default: English)
            template_dict: Optional additional template values

        Returns:
            Translated text segment with line numbers preserved
        """

        # Extract main segment and context
        lines = num_text.numbered_lines

        # Calculate context ranges
        preceding_start = max(1, start_line - self.context_lines)  # lines start on 1.
        following_end = min(num_text.end + 1, end_line + self.context_lines)

        # Extract context and segment
        preceding_context = num_text.get_numbered_segment(preceding_start, start_line)
        transcript_segment = num_text.get_numbered_segment(start_line, end_line)
        following_context = num_text.get_numbered_segment(end_line, following_end)

        # build input text
        translation_input = self._build_translation_input(
            preceding_context, transcript_segment, following_context
        )

        # Prepare template values
        template_values = {
            "source_language": source_language,
            "target_language": target_language,
            "review_count": self.review_count,
            "style": self.style,
        }

        if template_dict:
            template_values |= template_dict

        # Get and apply translation instructions
        logger.info(f"Translating segment (lines {start_line}-{end_line})")
        translate_instructions = self.pattern.apply_template(template_values)

        if start_line <= 1:
            logger.debug(
                f"Translate instructions (first segment):\n{translate_instructions}"
            )

        logger.debug(f"Translation input:\n{translation_input}")

        return self.processor.process_text(translation_input, translate_instructions)

    def _build_translation_input(
        self, preceding_context: str, transcript_segment: str, following_context: str
    ) -> str:
        """
        Build input text in required XML-style format.

        Args:
            preceding_context: Context lines before segment
            transcript_segment: Main segment to translate
            following_context: Context lines after segment

        Returns:
            Formatted input text
        """
        parts = []

        # Add preceding context if exists
        if preceding_context:
            parts.extend(
                [
                    PRECEDING_CONTEXT_MARKER,
                    preceding_context,
                    PRECEDING_CONTEXT_MARKER,
                    "",
                ]
            )

        # Add main segment (always required)
        parts.extend(
            [
                TRANSCRIPT_SEGMENT_MARKER,
                transcript_segment,
                TRANSCRIPT_SEGMENT_MARKER,
                "",
            ]
        )

        # Add following context if exists
        if following_context:
            parts.extend(
                [
                    FOLLOWING_CONTEXT_MARKER,
                    following_context,
                    FOLLOWING_CONTEXT_MARKER,
                    "",
                ]
            )

        return "\n".join(parts)

    def translate_text(
        self,
        text: str,
        segment_size: Optional[int] = None,  # Number of lines per segment
        source_language: Optional[str] = None,
        target_language: Optional[str] = None,
        template_dict: Optional[Dict] = None,
    ) -> str:
        """
        Translate entire text in segments while maintaining line continuity.

        Args:
            text: Text to translate
            segment_size: Number of lines per translation segment
            source_language: Source language code
            target_language: Target language code (default: English)
            template_dict: Optional additional template values

        Returns:
            Complete translated text with line numbers preserved
        """

        # Auto-detect language if not specified
        if not source_language:
            source_language = get_language_name(text)

        # Convert text to numbered lines
        num_text = NumberedText(text)
        total_lines = num_text.size

        if not segment_size:
            segment_size = _calculate_segment_size(
                num_text, DEFAULT_TRANSLATION_TARGET_TOKENS
            )

        translated_segments = []

        logger.debug(
            f"Total lines to translate: {total_lines} | Translation segment size: {segment_size}."
        )
        # Process text in segments using segment iteration
        for start_idx, end_idx in num_text.iter_segments(
            segment_size, segment_size // 5
        ):
            translated_segment = self.translate_segment(
                num_text,
                start_idx,
                end_idx,
                source_language,
                target_language,
                template_dict,
            )

            # validate the translated segment
            translated_content = self._extract_content(translated_segment)
            self._validate_segment(translated_content, start_idx, end_idx)

            translated_segments.append(translated_content)

        return "\n".join(translated_segments)

    def _extract_content(self, segment: str) -> str:
        segment = segment.strip()  # remove any filling whitespace
        if segment.startswith(TRANSCRIPT_SEGMENT_MARKER) and segment.endswith(
            TRANSCRIPT_SEGMENT_MARKER
        ):
            return segment[
                len(TRANSCRIPT_SEGMENT_MARKER) : -len(TRANSCRIPT_SEGMENT_MARKER)
            ].strip()
        logger.warning("Translated segment missing transcript_segment tags")
        return segment

    def _validate_segment(
        self, translated_content: str, start_index: int, end_index: int
    ) -> None:
        """
        Validate translated segment format, content, and line number sequence.
        Issues warnings for validation issues rather than raising errors.

        Args:
            translated_segment: Translated text to validate
            start_idx: the staring index of the range (inclusive)
            end_line: then ending index of the range (exclusive)

        Returns:
            str: Content with segment tags removed
        """

        # Validate lines

        lines = translated_content.splitlines()
        line_numbers = []

        start_line = start_index  # inclusive start
        end_line = end_index - 1  # exclusive end

        for line in lines:
            line = line.strip()
            if not line:
                continue

            if ":" not in line:
                logger.warning(f"Invalid line format: {line}")
                continue

            try:
                line_num = int(line[: line.index(":")])
                if line_num < 0:
                    logger.warning(f"Invalid line number: {line}")
                    continue
                line_numbers.append(line_num)
            except ValueError:
                logger.warning(f"Line number parsing failed: {line}")
                continue

        # Validate sequence
        if not line_numbers:
            logger.warning("No valid line numbers found")
        else:
            if line_numbers[0] != start_line:
                logger.warning(
                    f"First line number {line_numbers[0]} doesn't match expected {start_line}"
                )

            if line_numbers[-1] != end_line:
                logger.warning(
                    f"Last line number {line_numbers[-1]} doesn't match expected {end_line}"
                )

            expected = set(range(start_line, end_line + 1))
            if missing := expected - set(line_numbers):
                logger.warning(f"Missing line numbers in sequence: {missing}")

        logger.debug(f"Validated {len(lines)} lines from {start_line} to {end_line}")
context_lines = context_lines instance-attribute
pattern = pattern instance-attribute
processor = processor instance-attribute
review_count = review_count instance-attribute
style = style instance-attribute
__init__(processor, pattern, review_count=DEFAULT_REVIEW_COUNT, style=DEFAULT_TRANSLATE_STYLE, context_lines=DEFAULT_TRANSLATE_CONTEXT_LINES)

Initialize line translator.

Parameters:

  • processor (TextProcessor): Implementation of TextProcessor. Required.
  • pattern (Pattern): Pattern object containing translation instructions. Required.
  • review_count (int): Number of review passes. Default: DEFAULT_REVIEW_COUNT.
  • style (str): Translation style to apply. Default: DEFAULT_TRANSLATE_STYLE.
  • context_lines (int): Number of context lines to include before/after. Default: DEFAULT_TRANSLATE_CONTEXT_LINES.
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    processor: TextProcessor,
    pattern: Pattern,
    review_count: int = DEFAULT_REVIEW_COUNT,
    style: str = DEFAULT_TRANSLATE_STYLE,
    context_lines: int = DEFAULT_TRANSLATE_CONTEXT_LINES,  # Number of context lines before/after
):
    """
    Initialize line translator.

    Args:
        processor: Implementation of TextProcessor
        pattern: Pattern object containing translation instructions
        review_count: Number of review passes
        style: Translation style to apply
        context_lines: Number of context lines to include before/after
    """
    self.processor = processor
    self.pattern = pattern
    self.review_count = review_count
    self.style = style
    self.context_lines = context_lines
translate_segment(num_text, start_line, end_line, source_language=None, target_language=DEFAULT_TARGET_LANGUAGE, template_dict=None)

Translate a segment of text with context.

Parameters:

  • num_text (NumberedText): Numbered text to extract the segment and its context from. Required.
  • start_line (int): Starting line number of the segment. Required.
  • end_line (int): Ending line number of the segment. Required.
  • source_language (Optional[str]): Source language code. Default: None.
  • target_language (Optional[str]): Target language code. Default: DEFAULT_TARGET_LANGUAGE.
  • template_dict (Optional[Dict]): Optional additional template values. Default: None.

Returns:

  • str: Translated text segment with line numbers preserved.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def translate_segment(
    self,
    num_text: NumberedText,
    start_line: int,
    end_line: int,
    source_language: Optional[str] = None,
    target_language: Optional[str] = DEFAULT_TARGET_LANGUAGE,
    template_dict: Optional[Dict] = None,
) -> str:
    """
    Translate a segment of text with context.

    Args:
        num_text: NumberedText containing the full text to extract the segment from
        start_line: Starting line number of segment
        end_line: Ending line number of segment
        source_language: Source language code
        target_language: Target language code (default: English)
        template_dict: Optional additional template values

    Returns:
        Translated text segment with line numbers preserved
    """

    # Extract main segment and context
    lines = num_text.numbered_lines

    # Calculate context ranges
    preceding_start = max(1, start_line - self.context_lines)  # lines start on 1.
    following_end = min(num_text.end + 1, end_line + self.context_lines)

    # Extract context and segment
    preceding_context = num_text.get_numbered_segment(preceding_start, start_line)
    transcript_segment = num_text.get_numbered_segment(start_line, end_line)
    following_context = num_text.get_numbered_segment(end_line, following_end)

    # build input text
    translation_input = self._build_translation_input(
        preceding_context, transcript_segment, following_context
    )

    # Prepare template values
    template_values = {
        "source_language": source_language,
        "target_language": target_language,
        "review_count": self.review_count,
        "style": self.style,
    }

    if template_dict:
        template_values |= template_dict

    # Get and apply translation instructions
    logger.info(f"Translating segment (lines {start_line}-{end_line})")
    translate_instructions = self.pattern.apply_template(template_values)

    if start_line <= 1:
        logger.debug(
            f"Translate instructions (first segment):\n{translate_instructions}"
        )

    logger.debug(f"Translation input:\n{translation_input}")

    return self.processor.process_text(translation_input, translate_instructions)
translate_text(text, segment_size=None, source_language=None, target_language=None, template_dict=None)

Translate entire text in segments while maintaining line continuity.

Parameters:

  • text (str): Text to translate. Required.
  • segment_size (Optional[int]): Number of lines per translation segment. Default: None.
  • source_language (Optional[str]): Source language code. Default: None.
  • target_language (Optional[str]): Target language code. Default: None.
  • template_dict (Optional[Dict]): Optional additional template values. Default: None.

Returns:

  • str: Complete translated text with line numbers preserved.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def translate_text(
    self,
    text: str,
    segment_size: Optional[int] = None,  # Number of lines per segment
    source_language: Optional[str] = None,
    target_language: Optional[str] = None,
    template_dict: Optional[Dict] = None,
) -> str:
    """
    Translate entire text in segments while maintaining line continuity.

    Args:
        text: Text to translate
        segment_size: Number of lines per translation segment
        source_language: Source language code
        target_language: Target language code (default: English)
        template_dict: Optional additional template values

    Returns:
        Complete translated text with line numbers preserved
    """

    # Auto-detect language if not specified
    if not source_language:
        source_language = get_language_name(text)

    # Convert text to numbered lines
    num_text = NumberedText(text)
    total_lines = num_text.size

    if not segment_size:
        segment_size = _calculate_segment_size(
            num_text, DEFAULT_TRANSLATION_TARGET_TOKENS
        )

    translated_segments = []

    logger.debug(
        f"Total lines to translate: {total_lines} | Translation segment size: {segment_size}."
    )
    # Process text in segments using segment iteration
    for start_idx, end_idx in num_text.iter_segments(
        segment_size, segment_size // 5
    ):
        translated_segment = self.translate_segment(
            num_text,
            start_idx,
            end_idx,
            source_language,
            target_language,
            template_dict,
        )

        # validate the translated segment
        translated_content = self._extract_content(translated_segment)
        self._validate_segment(translated_content, start_idx, end_idx)

        translated_segments.append(translated_content)

    return "\n".join(translated_segments)
LocalPatternManager

A simple singleton implementation of PatternManager that ensures only one instance is created and reused throughout the application lifecycle.

This class wraps the PatternManager to provide efficient pattern loading by maintaining a single reusable instance.

Attributes:

  • _instance (Optional[LocalPatternManager]): The singleton instance.
  • _pattern_manager (Optional[PatternManager]): The wrapped PatternManager instance.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class LocalPatternManager:
    """
    A simple singleton implementation of PatternManager that ensures only one instance
    is created and reused throughout the application lifecycle.

    This class wraps the PatternManager to provide efficient pattern loading by
    maintaining a single reusable instance.

    Attributes:
        _instance (Optional[LocalPatternManager]): The singleton instance
        _pattern_manager (Optional[PatternManager]): The wrapped PatternManager instance
    """

    _instance: Optional["LocalPatternManager"] = None

    def __new__(cls) -> "LocalPatternManager":
        """
        Create or return the singleton instance.

        Returns:
            LocalPatternManager: The singleton instance
        """
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._pattern_manager = None
        return cls._instance

    @property
    def pattern_manager(self) -> "PatternManager":
        """
        Lazy initialization of the PatternManager instance.

        Returns:
            PatternManager: The wrapped PatternManager instance

        Raises:
            RuntimeError: If PATTERN_REPO is not properly configured
        """
        if self._pattern_manager is None:  # type: ignore
            try:
                load_dotenv()
                if pattern_path_name := os.getenv("TNH_PATTERN_DIR"):
                    pattern_dir = Path(pattern_path_name)
                    logger.debug(f"pattern dir: {pattern_path_name}")
                else:
                    pattern_dir = TNH_DEFAULT_PATTERN_DIR
                self._pattern_manager = PatternManager(pattern_dir)
            except ImportError as err:
                raise RuntimeError(
                    "Failed to initialize PatternManager. Ensure pattern_manager "
                    f"module and PATTERN_REPO are properly configured: {err}"
                ) from err
        return self._pattern_manager
pattern_manager property

Lazy initialization of the PatternManager instance.

Returns:

  • PatternManager: The wrapped PatternManager instance.

Raises:

  • RuntimeError: If PATTERN_REPO is not properly configured.

__new__()

Create or return the singleton instance.

Returns:

  • LocalPatternManager: The singleton instance.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __new__(cls) -> "LocalPatternManager":
    """
    Create or return the singleton instance.

    Returns:
        LocalPatternManager: The singleton instance
    """
    if cls._instance is None:
        cls._instance = super().__new__(cls)
        cls._instance._pattern_manager = None
    return cls._instance
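
Because LocalPatternManager is a singleton, repeated construction is cheap and always yields the same wrapped PatternManager. A minimal sketch using only the documented property:

from tnh_scholar.ai_text_processing.ai_text_processing import LocalPatternManager

a = LocalPatternManager()
b = LocalPatternManager()
assert a is b  # same singleton instance

# The wrapped PatternManager is created lazily on first access and honors
# TNH_PATTERN_DIR, falling back to TNH_DEFAULT_PATTERN_DIR.
pm = a.pattern_manager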
OpenAIProcessor

Bases: TextProcessor

OpenAI-based text processor implementation.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class OpenAIProcessor(TextProcessor):
    """OpenAI-based text processor implementation."""

    def __init__(self, model: Optional[str] = None, max_tokens: int = 0):
        if not model:
            model = DEFAULT_OPENAI_MODEL
        self.model = model
        self.max_tokens = max_tokens

    def process_text(
        self,
        text: str,
        instructions: str,
        response_format: Optional[Type[ResponseFormat]] = None,
        max_tokens: int = 0,
        **kwargs,
    ) -> Union[str, ResponseFormat]:
        """Process text using OpenAI API with optional structured output."""

        if max_tokens == 0 and self.max_tokens > 0:
            max_tokens = self.max_tokens

        return openai_process_text(
            text,
            instructions,
            model=self.model,
            max_tokens=max_tokens,
            response_format=response_format,
            **kwargs,
        )
max_tokens = max_tokens instance-attribute
model = model instance-attribute
__init__(model=None, max_tokens=0)
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(self, model: Optional[str] = None, max_tokens: int = 0):
    if not model:
        model = DEFAULT_OPENAI_MODEL
    self.model = model
    self.max_tokens = max_tokens
process_text(text, instructions, response_format=None, max_tokens=0, **kwargs)

Process text using OpenAI API with optional structured output.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text(
    self,
    text: str,
    instructions: str,
    response_format: Optional[Type[ResponseFormat]] = None,
    max_tokens: int = 0,
    **kwargs,
) -> Union[str, ResponseFormat]:
    """Process text using OpenAI API with optional structured output."""

    if max_tokens == 0 and self.max_tokens > 0:
        max_tokens = self.max_tokens

    return openai_process_text(
        text,
        instructions,
        model=self.model,
        max_tokens=max_tokens,
        response_format=response_format,
        **kwargs,
    )
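
A minimal sketch of using OpenAIProcessor directly. It requires OPENAI_API_KEY in the environment; the instruction string is illustrative, not a shipped pattern:

from tnh_scholar.ai_text_processing.ai_text_processing import OpenAIProcessor

processor = OpenAIProcessor()  # model defaults to DEFAULT_OPENAI_MODEL ('gpt-4o')

result = processor.process_text(
    "dear friends today we will look deeply at the practice of mindful breathing",
    "Add sentence punctuation and capitalization; return only the corrected text.",
)
print(result)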
ProcessedSection dataclass

Represents a processed section of text with its metadata.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
@dataclass
class ProcessedSection:
    """Represents a processed section of text with its metadata."""

    title: str
    original_text: str
    processed_text: str
    start_line: int
    end_line: int
    metadata: Dict = field(default_factory=dict)
end_line instance-attribute
metadata = field(default_factory=dict) class-attribute instance-attribute
original_text instance-attribute
processed_text instance-attribute
start_line instance-attribute
title instance-attribute
__init__(title, original_text, processed_text, start_line, end_line, metadata=dict())
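
ProcessedSection is a plain dataclass, so it can also be constructed directly; the field values below are illustrative:

from tnh_scholar.ai_text_processing.ai_text_processing import ProcessedSection

section = ProcessedSection(
    title="Opening the Dharma Talk",
    original_text="1: Dear friends...\n2: Today we look deeply...",
    processed_text="<section>...</section>",
    start_line=1,
    end_line=2,
    metadata={"pattern": "default_xml_format"},  # optional; defaults to {}
)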
SectionParser

Generates structured section breakdowns of text content.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class SectionParser:
    """Generates structured section breakdowns of text content."""

    def __init__(
        self,
        section_scanner: TextProcessor,
        section_pattern: Pattern,
        review_count: int = DEFAULT_REVIEW_COUNT,
    ):
        """
        Initialize section generator.

        Args:
            processor: Implementation of TextProcessor
            pattern: Pattern object containing section generation instructions
            max_tokens: Maximum tokens for response
            section_count: Target number of sections
            review_count: Number of review passes
        """
        self.section_scanner = section_scanner
        self.section_pattern = section_pattern
        self.review_count = review_count

    def find_sections(
        self,
        text: str,
        source_language: Optional[str] = None,
        section_count_target: Optional[int] = None,
        segment_size_target: Optional[int] = None,
        template_dict: Optional[Dict[str, str]] = None,
    ) -> TextObject:
        """
        Generate section breakdown of input text. The text must be split up by newlines.

        Args:
            text: Input text to process
            source_language: ISO 639-1 language code, or None for autodetection
            section_count_target: the target for the number of sections to find
            segment_size_target: the target for the number of lines per section
                (if section_count_target is specified, this value will be set to generate correct segments)
            template_dict: Optional additional template variables

        Returns:
            TextObject containing section breakdown

        Raises:
            ValidationError: If response doesn't match TextObject schema
        """

        # Prepare numbered text, each line is numbered
        num_text = NumberedText(text)

        if num_text.size < SECTION_SEGMENT_SIZE_WARNING_LIMIT:
            logger.warning(
                f"find_sections: Text has only {num_text.size} lines. This may lead to unexpected sectioning results."
            )

        # Get language if not specified
        if not source_language:
            source_language = get_language_name(text)

        # determine section count if not specified
        if not section_count_target:
            segment_size_target, section_count_target = self._get_section_count_info(
                text
            )
        elif not segment_size_target:
            segment_size_target = round(num_text.size / section_count_target)

        section_count_range = self._get_section_count_range(section_count_target)

        # Prepare template variables
        template_values = {
            "source_language": source_language,
            "section_count": section_count_range,
            "line_count": segment_size_target,
            "review_count": self.review_count,
        }

        if template_dict:
            template_values |= template_dict

        # Get and apply processing instructions
        instructions = self.section_pattern.apply_template(template_values)
        logger.debug(f"Finding sections with pattern instructions:\n {instructions}")

        logger.info(
            f"Finding sections for {source_language} text "
            f"(target sections: {section_count_target})"
        )

        # Process text with structured output
        try:
            result = self.section_scanner.process_text(
                str(num_text), instructions, response_format=TextObject
            )

            # Validate section coverage
            self._validate_sections(result.sections, num_text.size)

            return result

        except Exception as e:
            logger.error(f"Section generation failed: {e}")
            raise

    def _get_section_count_info(self, text: str) -> Tuple[int, int]:
        num_text = NumberedText(text)
        segment_size = _calculate_segment_size(num_text, DEFAULT_SECTION_TOKEN_SIZE)
        section_count_target = round(num_text.size / segment_size)
        return segment_size, section_count_target

    def _get_section_count_range(
        self,
        section_count_target: int,
        section_range_var: int = DEFAULT_SECTION_RANGE_VAR,
    ) -> str:
        low = max(1, section_count_target - section_range_var)
        high = section_count_target + section_range_var
        return f"{low}-{high}"

    def _validate_sections(
        self, sections: List[LogicalSection], total_lines: int
    ) -> None:
        """
        Validate section line coverage and ordering. Issues warnings for validation problems
        instead of raising errors.

        Args:
            sections: List of generated sections
            total_lines: Total number of lines in the original text
        """

        covered_lines = set()
        last_end = -1

        for section in sections:
            # Check line ordering
            if section.start_line <= last_end:
                logger.warning(
                    f"Section lines should be sequential but found overlap: "
                    f"section starting at {section.start_line} begins before or at "
                    f"previous section end {last_end}"
                )

            # Track line coverage
            section_lines = set(range(section.start_line, section.end_line + 1))
            if section_lines & covered_lines:
                logger.warning(
                    f"Found overlapping lines in section '{section.title_en}'. "
                    f"Each line should belong to exactly one section."
                )
            covered_lines.update(section_lines)

            last_end = section.end_line

        # Check complete coverage
        expected_lines = set(range(1, total_lines + 1))
        if covered_lines != expected_lines:
            missing = sorted(list(expected_lines - covered_lines))
            logger.warning(
                f"Not all lines are covered by sections. "
                f"Missing line numbers: {missing}"
            )
review_count = review_count instance-attribute
section_pattern = section_pattern instance-attribute
section_scanner = section_scanner instance-attribute
__init__(section_scanner, section_pattern, review_count=DEFAULT_REVIEW_COUNT)

Initialize section generator.

Parameters:

  • section_scanner (TextProcessor): Implementation of TextProcessor used to scan for sections. Required.
  • section_pattern (Pattern): Pattern object containing section generation instructions. Required.
  • review_count (int): Number of review passes. Default: DEFAULT_REVIEW_COUNT.
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    section_scanner: TextProcessor,
    section_pattern: Pattern,
    review_count: int = DEFAULT_REVIEW_COUNT,
):
    """
    Initialize section generator.

    Args:
        processor: Implementation of TextProcessor
        pattern: Pattern object containing section generation instructions
        max_tokens: Maximum tokens for response
        section_count: Target number of sections
        review_count: Number of review passes
    """
    self.section_scanner = section_scanner
    self.section_pattern = section_pattern
    self.review_count = review_count
find_sections(text, source_language=None, section_count_target=None, segment_size_target=None, template_dict=None)

Generate section breakdown of input text. The text must be split up by newlines.

Parameters:

  • text (str): Input text to process. Required.
  • source_language (Optional[str]): ISO 639-1 language code, or None for autodetection. Default: None.
  • section_count_target (Optional[int]): Target number of sections to find. Default: None.
  • segment_size_target (Optional[int]): Target number of lines per section; if section_count_target is specified, this value is derived from it. Default: None.
  • template_dict (Optional[Dict[str, str]]): Optional additional template variables. Default: None.

Returns:

  • TextObject: TextObject containing the section breakdown.

Raises:

  • ValidationError: If the response doesn't match the TextObject schema.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def find_sections(
    self,
    text: str,
    source_language: Optional[str] = None,
    section_count_target: Optional[int] = None,
    segment_size_target: Optional[int] = None,
    template_dict: Optional[Dict[str, str]] = None,
) -> TextObject:
    """
    Generate section breakdown of input text. The text must be split up by newlines.

    Args:
        text: Input text to process
        source_language: ISO 639-1 language code, or None for autodetection
        section_count_target: the target for the number of sections to find
        segment_size_target: the target for the number of lines per section
            (if section_count_target is specified, this value will be set to generate correct segments)
        template_dict: Optional additional template variables

    Returns:
        TextObject containing section breakdown

    Raises:
        ValidationError: If response doesn't match TextObject schema
    """

    # Prepare numbered text, each line is numbered
    num_text = NumberedText(text)

    if num_text.size < SECTION_SEGMENT_SIZE_WARNING_LIMIT:
        logger.warning(
            f"find_sections: Text has only {num_text.size} lines. This may lead to unexpected sectioning results."
        )

    # Get language if not specified
    if not source_language:
        source_language = get_language_name(text)

    # determine section count if not specified
    if not section_count_target:
        segment_size_target, section_count_target = self._get_section_count_info(
            text
        )
    elif not segment_size_target:
        segment_size_target = round(num_text.size / section_count_target)

    section_count_range = self._get_section_count_range(section_count_target)

    # Prepare template variables
    template_values = {
        "source_language": source_language,
        "section_count": section_count_range,
        "line_count": segment_size_target,
        "review_count": self.review_count,
    }

    if template_dict:
        template_values |= template_dict

    # Get and apply processing instructions
    instructions = self.section_pattern.apply_template(template_values)
    logger.debug(f"Finding sections with pattern instructions:\n {instructions}")

    logger.info(
        f"Finding sections for {source_language} text "
        f"(target sections: {section_count_target})"
    )

    # Process text with structured output
    try:
        result = self.section_scanner.process_text(
            str(num_text), instructions, response_format=TextObject
        )

        # Validate section coverage
        self._validate_sections(result.sections, num_text.size)

        return result

    except Exception as e:
        logger.error(f"Section generation failed: {e}")
        raise
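
A usage sketch for SectionParser. The pattern name comes from DEFAULT_SECTION_PATTERN above; get_pattern is again a hypothetical stand-in for the actual PatternManager lookup:

from tnh_scholar.ai_text_processing.ai_text_processing import (
    LocalPatternManager,
    OpenAIProcessor,
    SectionParser,
)

# Hypothetical pattern lookup: substitute the actual PatternManager retrieval call.
pattern = LocalPatternManager().pattern_manager.get_pattern("default_section")

parser = SectionParser(
    section_scanner=OpenAIProcessor(),
    section_pattern=pattern,
)

with open("talk_transcript.txt", encoding="utf-8") as f:
    text_object = parser.find_sections(f.read(), source_language="en")

for section in text_object.sections:
    print(section.title_en, section.start_line, section.end_line)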
SectionProcessor

Handles section-based XML text processing with configurable output handling.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class SectionProcessor:
    """Handles section-based XML text processing with configurable output handling."""

    def __init__(
        self,
        processor: TextProcessor,
        pattern: Pattern,
        template_dict: Dict,
        wrap_in_document: bool = True,
    ):
        """
        Initialize the XML section processor.

        Args:
            processor: Implementation of TextProcessor to use
            pattern: Pattern object containing processing instructions
            template_dict: Dictionary for template substitution
            wrap_in_document: Whether to wrap output in <document> tags
        """
        self.processor = processor
        self.pattern = pattern
        self.template_dict = template_dict
        self.wrap_in_document = wrap_in_document

    def process_sections(
        self,
        transcript: str,
        text_object: TextObject,
    ) -> Generator[ProcessedSection, None, None]:
        """
        Process transcript sections and yield results one section at a time.

        Args:
            transcript: Text to process
            text_object: Object containing section definitions

        Yields:
            ProcessedSection: One processed section at a time, containing:
                - title: Section title (English or original language)
                - original_text: Raw text segment
                - processed_text: Processed text content
                - start_line: Starting line number
                - end_line: Ending line number
        """
        numbered_transcript = NumberedText(transcript)
        sections = text_object.sections

        logger.info(
            f"Processing {len(sections)} sections with pattern: {self.pattern.name}"
        )

        for i, section in enumerate(sections, 1):
            logger.info(f"Processing section {i}, '{section.title}':")

            # Get text segment for section
            text_segment = numbered_transcript.get_segment(
                section.start_line, end=section.end_line
            )

            # Prepare template variables
            template_values = {
                "section_title": section.title,
                "source_language": text_object.language,
                "review_count": DEFAULT_REVIEW_COUNT,
            }

            if self.template_dict:
                template_values |= self.template_dict

            # Get and apply processing instructions
            instructions = self.pattern.apply_template(template_values)
            if i <= 1:
                logger.debug(f"Process instructions (first section):\n{instructions}")
            processed_text = self.processor.process_text(text_segment, instructions)

            yield ProcessedSection(
                title=section.title,
                original_text=text_segment,
                processed_text=processed_text,
                start_line=section.start_line,
                end_line=section.end_line,
            )

    def process_paragraphs(
        self,
        transcript: str,
    ) -> Generator[str, None, None]:
        """
        Process transcript by paragraphs (as sections) where paragraphs are assumed to be given as newline separated.

        Args:
            transcript: Text to process

        Returns:
            Generator of lines

        Yields:
            Processed lines as strings
        """
        numbered_transcript = NumberedText(transcript)

        logger.info(f"Processing lines as paragraphs with pattern: {self.pattern.name}")

        for i, line in numbered_transcript:

            # If line is empty or whitespace, continue
            if not line.strip():
                continue

            # Otherwise get and apply processing instructions
            instructions = self.pattern.apply_template(self.template_dict)

            if i <= 1:
                logger.debug(f"Process instructions (first paragraph):\n{instructions}")
            yield self.processor.process_text(line, instructions)
pattern = pattern instance-attribute
processor = processor instance-attribute
template_dict = template_dict instance-attribute
wrap_in_document = wrap_in_document instance-attribute
__init__(processor, pattern, template_dict, wrap_in_document=True)

Initialize the XML section processor.

Parameters:

  • processor (TextProcessor): Implementation of TextProcessor to use. Required.
  • pattern (Pattern): Pattern object containing processing instructions. Required.
  • template_dict (Dict): Dictionary for template substitution. Required.
  • wrap_in_document (bool): Whether to wrap output in <document> tags. Default: True.
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    processor: TextProcessor,
    pattern: Pattern,
    template_dict: Dict,
    wrap_in_document: bool = True,
):
    """
    Initialize the XML section processor.

    Args:
        processor: Implementation of TextProcessor to use
        pattern: Pattern object containing processing instructions
        template_dict: Dictionary for template substitution
        wrap_in_document: Whether to wrap output in <document> tags
    """
    self.processor = processor
    self.pattern = pattern
    self.template_dict = template_dict
    self.wrap_in_document = wrap_in_document
process_paragraphs(transcript)

Process transcript by paragraphs (as sections) where paragraphs are assumed to be given as newline separated.

Parameters:

  • transcript (str): Text to process. Required.

Returns:

  • Generator of processed lines.

Yields:

  • str: Processed lines as strings.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_paragraphs(
    self,
    transcript: str,
) -> Generator[str, None, None]:
    """
    Process transcript by paragraphs (as sections) where paragraphs are assumed to be given as newline separated.

    Args:
        transcript: Text to process

    Returns:
        Generator of lines

    Yields:
        Processed lines as strings
    """
    numbered_transcript = NumberedText(transcript)

    logger.info(f"Processing lines as paragraphs with pattern: {self.pattern.name}")

    for i, line in numbered_transcript:

        # If line is empty or whitespace, continue
        if not line.strip():
            continue

        # Otherwise get and apply processing instructions
        instructions = self.pattern.apply_template(self.template_dict)

        if i <= 1:
            logger.debug(f"Process instructions (first paragraph):\n{instructions}")
        yield self.processor.process_text(line, instructions)
process_sections(transcript, text_object)

Process transcript sections and yield results one section at a time.

Parameters:

  • transcript (str): Text to process. Required.
  • text_object (TextObject): Object containing section definitions. Required.

Yields:

  • ProcessedSection: One processed section at a time, containing:
    • title: Section title (English or original language)
    • original_text: Raw text segment
    • processed_text: Processed text content
    • start_line: Starting line number
    • end_line: Ending line number

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_sections(
    self,
    transcript: str,
    text_object: TextObject,
) -> Generator[ProcessedSection, None, None]:
    """
    Process transcript sections and yield results one section at a time.

    Args:
        transcript: Text to process
        text_object: Object containing section definitions

    Yields:
        ProcessedSection: One processed section at a time, containing:
            - title: Section title (English or original language)
            - original_text: Raw text segment
            - processed_text: Processed text content
            - start_line: Starting line number
            - end_line: Ending line number
    """
    numbered_transcript = NumberedText(transcript)
    sections = text_object.sections

    logger.info(
        f"Processing {len(sections)} sections with pattern: {self.pattern.name}"
    )

    for i, section in enumerate(sections, 1):
        logger.info(f"Processing section {i}, '{section.title}':")

        # Get text segment for section
        text_segment = numbered_transcript.get_segment(
            section.start_line, end=section.end_line
        )

        # Prepare template variables
        template_values = {
            "section_title": section.title,
            "source_language": text_object.language,
            "review_count": DEFAULT_REVIEW_COUNT,
        }

        if self.template_dict:
            template_values |= self.template_dict

        # Get and apply processing instructions
        instructions = self.pattern.apply_template(template_values)
        if i <= 1:
            logger.debug(f"Process instructions (first section):\n{instructions}")
        processed_text = self.processor.process_text(text_segment, instructions)

        yield ProcessedSection(
            title=section.title,
            original_text=text_segment,
            processed_text=processed_text,
            start_line=section.start_line,
            end_line=section.end_line,
        )
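
A usage sketch that chains SectionParser and SectionProcessor. The pattern names follow the defaults above; the get_pattern calls are hypothetical stand-ins for the actual PatternManager lookup:

from tnh_scholar.ai_text_processing.ai_text_processing import (
    LocalPatternManager,
    OpenAIProcessor,
    SectionParser,
    SectionProcessor,
)

pm = LocalPatternManager().pattern_manager

with open("talk_transcript.txt", encoding="utf-8") as f:
    transcript = f.read()

# Find sections first, then process each one with a formatting pattern.
text_object = SectionParser(
    section_scanner=OpenAIProcessor(),
    section_pattern=pm.get_pattern("default_section"),  # hypothetical lookup
).find_sections(transcript)

section_processor = SectionProcessor(
    processor=OpenAIProcessor(),
    pattern=pm.get_pattern("default_xml_format"),  # hypothetical lookup
    template_dict={},
)

for processed in section_processor.process_sections(transcript, text_object):
    print(processed.title, processed.start_line, processed.end_line)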
TextProcessor

Bases: ABC

Abstract base class for text processors that can return Pydantic objects.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class TextProcessor(ABC):
    """Abstract base class for text processors that can return Pydantic objects."""

    @abstractmethod
    def process_text(
        self,
        text: str,
        instructions: str,
        response_format: Optional[Type[ResponseFormat]] = None,
        **kwargs,
    ) -> Union[str, ResponseFormat]:
        """
        Process text according to instructions.

        Args:
            text: Input text to process
            instructions: Processing instructions
            response_format: Optional Pydantic class for structured output
            **kwargs: Additional processing parameters

        Returns:
            Either string or Pydantic model instance based on response_format
        """
        pass
process_text(text, instructions, response_format=None, **kwargs) abstractmethod

Process text according to instructions.

Parameters:

Name Type Description Default
text str

Input text to process

required
instructions str

Processing instructions

required
response_format Optional[Type[ResponseFormat]]

Optional Pydantic class for structured output

None
**kwargs

Additional processing parameters

{}

Returns:

Type Description
Union[str, ResponseFormat]

Either string or Pydantic model instance based on response_format

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
@abstractmethod
def process_text(
    self,
    text: str,
    instructions: str,
    response_format: Optional[Type[ResponseFormat]] = None,
    **kwargs,
) -> Union[str, ResponseFormat]:
    """
    Process text according to instructions.

    Args:
        text: Input text to process
        instructions: Processing instructions
        response_format: Optional Pydantic class for structured output
        **kwargs: Additional processing parameters

    Returns:
        Either string or Pydantic model instance based on response_format
    """
    pass
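
For illustration, a hypothetical TextProcessor subclass that satisfies the abstract interface without calling any model, which can stand in during dry runs. TextProcessor and ResponseFormat are assumed to be importable from this module; the class name is invented.

from typing import Optional, Type, Union

from tnh_scholar.ai_text_processing.ai_text_processing import (  # assumed import path
    ResponseFormat,
    TextProcessor,
)


class EchoProcessor(TextProcessor):
    """Trivial processor that returns the input text unchanged (for testing)."""

    def process_text(
        self,
        text: str,
        instructions: str,
        response_format: Optional[Type[ResponseFormat]] = None,
        **kwargs,
    ) -> Union[str, ResponseFormat]:
        # A real implementation would send `instructions` and `text` to a model;
        # here the text is simply echoed back so pipelines can be exercised offline.
        return text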
TextPunctuator
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
class TextPunctuator:
    def __init__(
        self,
        processor: TextProcessor,
        punctuate_pattern: Pattern,
        source_language: Optional[str] = None,
        review_count: int = DEFAULT_REVIEW_COUNT,
        style_convention=DEFAULT_PUNCTUATE_STYLE,
    ):
        """
        Initialize punctuation generator.

        Args:
            processor: TextProcessor implementation used for punctuation
            punctuate_pattern: Pattern object containing punctuation instructions
            source_language: Source language name; detected from the text if omitted
            review_count: Number of review passes
            style_convention: Punctuation style convention
        """

        self.source_language = source_language
        self.processor = processor
        self.punctuate_pattern = punctuate_pattern
        self.review_count = review_count
        self.style_convention = style_convention

    def punctuate_text(
        self,
        text: str,
        source_language: Optional[str] = None,
        template_dict: Optional[Dict] = None,
    ) -> str:
        """
        Punctuate a text based on a pattern and source language.
        """

        if not source_language:
            if self.source_language:
                source_language = self.source_language
            else:
                source_language = get_language_name(text)

        template_values = {
            "source_language": source_language,
            "review_count": self.review_count,
            "style_convention": self.style_convention,
        }

        if template_dict:
            template_values |= template_dict

        logger.info("Punctuating text...")
        punctuate_instructions = self.punctuate_pattern.apply_template(template_values)
        text = self.processor.process_text(text, punctuate_instructions)
        logger.info("Punctuation completed.")

        # Normalizing newline spacing to two newlines (default) between lines is
        # commented out to allow the pattern to dictate newlines.
        # return normalize_newlines(text)
        return text
processor = processor instance-attribute
punctuate_pattern = punctuate_pattern instance-attribute
review_count = review_count instance-attribute
source_language = source_language instance-attribute
style_convention = style_convention instance-attribute
__init__(processor, punctuate_pattern, source_language=None, review_count=DEFAULT_REVIEW_COUNT, style_convention=DEFAULT_PUNCTUATE_STYLE)

Initialize punctuation generator.

Parameters:

Name Type Description Default
processor TextProcessor

TextProcessor implementation used for punctuation

required
punctuate_pattern Pattern

Pattern object containing punctuation instructions

required
source_language Optional[str]

Source language name; detected from the text if omitted

None
review_count int

Number of review passes

DEFAULT_REVIEW_COUNT
style_convention

Punctuation style convention

DEFAULT_PUNCTUATE_STYLE
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def __init__(
    self,
    processor: TextProcessor,
    punctuate_pattern: Pattern,
    source_language: Optional[str] = None,
    review_count: int = DEFAULT_REVIEW_COUNT,
    style_convention=DEFAULT_PUNCTUATE_STYLE,
):
    """
    Initialize punctuation generator.

    Args:
        processor: TextProcessor implementation used for punctuation
        punctuate_pattern: Pattern object containing punctuation instructions
        source_language: Source language name; detected from the text if omitted
        review_count: Number of review passes
        style_convention: Punctuation style convention
    """

    self.source_language = source_language
    self.processor = processor
    self.punctuate_pattern = punctuate_pattern
    self.review_count = review_count
    self.style_convention = style_convention
punctuate_text(text, source_language=None, template_dict=None)

Punctuate a text based on a pattern and source language.

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def punctuate_text(
    self,
    text: str,
    source_language: Optional[str] = None,
    template_dict: Optional[Dict] = None,
) -> str:
    """
    Punctuate a text based on a pattern and source language.
    """

    if not source_language:
        if self.source_language:
            source_language = self.source_language
        else:
            source_language = get_language_name(text)

    template_values = {
        "source_language": source_language,
        "review_count": self.review_count,
        "style_convention": self.style_convention,
    }

    if template_dict:
        template_values |= template_dict

    logger.info("Punctuating text...")
    punctuate_instructions = self.punctuate_pattern.apply_template(template_values)
    text = self.processor.process_text(text, punctuate_instructions)
    logger.info("Punctuation completed.")

    # Normalizing newline spacing to two newlines (default) between lines is
    # commented out to allow the pattern to dictate newlines.
    # return normalize_newlines(text)
    return text
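
An illustrative sketch of driving TextPunctuator directly. The raw input string is a placeholder and the import path is an assumption; the constructor arguments follow the signature above.

from tnh_scholar.ai_text_processing.ai_text_processing import (  # assumed import path
    DEFAULT_PUNCTUATE_PATTERN,
    OpenAIProcessor,
    TextPunctuator,
    get_default_pattern,
)

punctuator = TextPunctuator(
    processor=OpenAIProcessor("gpt-4o"),
    punctuate_pattern=get_default_pattern(DEFAULT_PUNCTUATE_PATTERN),
    source_language="Vietnamese",
)

raw = "thua dai chung hom nay chung ta thuc tap hoi tho"  # placeholder unpunctuated text
print(punctuator.punctuate_text(raw))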
find_sections(text, source_language=None, section_pattern=None, section_model=None, max_tokens=DEFAULT_SECTION_RESULT_MAX_SIZE, section_count=None, review_count=DEFAULT_REVIEW_COUNT, template_dict=None)

High-level function for generating text sections.

Parameters:

Name Type Description Default
text str

Input text

required
source_language Optional[str]

ISO 639-1 language code

None
section_pattern Optional[Pattern]

Optional custom pattern (uses default if None)

None
section_model Optional[str]

Optional model identifier

None
max_tokens int

Maximum tokens for response

DEFAULT_SECTION_RESULT_MAX_SIZE
section_count Optional[int]

Target number of sections

None
review_count int

Number of review passes

DEFAULT_REVIEW_COUNT
template_dict Optional[Dict[str, str]]

Optional additional template variables

None

Returns:

Type Description
TextObject

TextObject containing section breakdown

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def find_sections(
    text: str,
    source_language: Optional[str] = None,
    section_pattern: Optional[Pattern] = None,
    section_model: Optional[str] = None,
    max_tokens: int = DEFAULT_SECTION_RESULT_MAX_SIZE,
    section_count: Optional[int] = None,
    review_count: int = DEFAULT_REVIEW_COUNT,
    template_dict: Optional[Dict[str, str]] = None,
) -> TextObject:
    """
    High-level function for generating text sections.

    Args:
        text: Input text
        source_language: ISO 639-1 language code
        section_pattern: Optional custom pattern (uses default if None)
        section_model: Optional model identifier
        max_tokens: Maximum tokens for response
        section_count: Target number of sections
        review_count: Number of review passes
        template_dict: Optional additional template variables

    Returns:
        TextObject containing section breakdown
    """
    if section_pattern is None:
        section_pattern = get_default_pattern(DEFAULT_SECTION_PATTERN)
        logger.debug(f"Using default section pattern: {DEFAULT_SECTION_PATTERN}.")

    if source_language is None:
        source_language = get_language_name(text)

    section_scanner = OpenAIProcessor(model=section_model, max_tokens=max_tokens)
    parser = SectionParser(
        section_scanner=section_scanner,
        section_pattern=section_pattern,
        review_count=review_count,
    )

    return parser.find_sections(
        text,
        source_language=source_language,
        section_count_target=section_count,
        template_dict=template_dict,
    )
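
A short sketch of calling find_sections on a transcript and inspecting the detected boundaries. The file name and section count are illustrative; the import path is an assumption.

from pathlib import Path

from tnh_scholar.ai_text_processing.ai_text_processing import find_sections  # assumed path

transcript = Path("dharma_talk.txt").read_text()     # hypothetical input
text_object = find_sections(transcript, section_count=8)

print(text_object.language)
for section in text_object.sections:
    # Section objects expose title and line range, as consumed by process_sections above.
    print(section.title, section.start_line, section.end_line)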
get_default_pattern(name)

Get a pattern by name using the singleton PatternManager.

This is a more efficient version that reuses a single PatternManager instance.

Parameters:

Name Type Description Default
name str

Name of the pattern to load

required

Returns:

Type Description
Pattern

The loaded pattern

Raises:

Type Description
ValueError

If pattern name is invalid

FileNotFoundError

If pattern file doesn't exist

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def get_default_pattern(name: str) -> Pattern:
    """
    Get a pattern by name using the singleton PatternManager.

    This is a more efficient version that reuses a single PatternManager instance.

    Args:
        name: Name of the pattern to load

    Returns:
        The loaded pattern

    Raises:
        ValueError: If pattern name is invalid
        FileNotFoundError: If pattern file doesn't exist
    """
    return LocalPatternManager().pattern_manager.load_pattern(name)
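
A sketch of loading a bundled pattern by name and rendering it with the same fields TextPunctuator supplies. The concrete field values are placeholders; which variables a pattern actually requires depends on its template.

from tnh_scholar.ai_text_processing.ai_text_processing import (  # assumed import path
    DEFAULT_PUNCTUATE_PATTERN,
    get_default_pattern,
)

pattern = get_default_pattern(DEFAULT_PUNCTUATE_PATTERN)
instructions = pattern.apply_template(
    {"source_language": "Vietnamese", "review_count": 5, "style_convention": "APA"}
)
print(instructions[:200])   # preview of the rendered instructions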
process_text(text, pattern, source_language=None, model=None, template_dict=None)
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text(
    text: str,
    pattern: Pattern,
    source_language: Optional[str] = None,
    model: Optional[str] = None,
    template_dict: Optional[Dict] = None,
) -> str:

    if not model:
        model = DEFAULT_OPENAI_MODEL

    processor = GeneralProcessor(
        processor=OpenAIProcessor(model),
        source_language=source_language,
        pattern=pattern,
    )

    return processor.process_text(
        text, source_language=source_language, template_dict=template_dict
    )
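
A sketch of the module-level process_text helper applied to a single block of text. The pattern choice and input string are illustrative; the import path is an assumption.

from tnh_scholar.ai_text_processing.ai_text_processing import (  # assumed import path
    DEFAULT_XML_FORMAT_PATTERN,
    get_default_pattern,
    process_text,
)

formatted = process_text(
    "breathing in i know i am breathing in breathing out i smile",  # placeholder text
    pattern=get_default_pattern(DEFAULT_XML_FORMAT_PATTERN),
    source_language="English",
)
print(formatted)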
process_text_by_paragraphs(transcript, template_dict, pattern=None, model=None)

High-level function for processing text paragraphs. Assumes paragraphs are separated by newlines. Uses DEFAULT_PARAGRAPH_FORMAT_PATTERN as the default pattern for text processing.

Parameters:

Name Type Description Default
transcript str

Text to process

required
pattern Optional[Pattern]

Pattern object containing processing instructions

None
template_dict Dict[str, str]

Dictionary for template substitution

required
model Optional[str]

Optional model identifier for processor

None

Returns:

Type Description
Generator[str, None, None]

Generator yielding the processed text for each paragraph

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text_by_paragraphs(
    transcript: str,
    template_dict: Dict[str, str],
    pattern: Optional[Pattern] = None,
    model: Optional[str] = None,
) -> Generator[str, None, None]:
    """
    High-level function for processing text paragraphs. Assumes paragraphs are separated by newlines.
    Uses DEFAULT_PARAGRAPH_FORMAT_PATTERN as the default pattern for text processing.

    Args:
        transcript: Text to process
        pattern: Pattern object containing processing instructions
        template_dict: Dictionary for template substitution
        model: Optional model identifier for processor


    Returns:
        Generator yielding the processed text for each paragraph
    """
    processor = OpenAIProcessor(model)

    if not pattern:
        pattern = get_default_pattern(DEFAULT_PARAGRAPH_FORMAT_PATTERN)

    section_processor = SectionProcessor(processor, pattern, template_dict)

    return section_processor.process_paragraphs(transcript)
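
A sketch of paragraph-by-paragraph processing, consuming the generator lazily. The transcript and template values are placeholders and must match whatever the chosen pattern expects.

from tnh_scholar.ai_text_processing.ai_text_processing import (  # assumed import path
    process_text_by_paragraphs,
)

transcript = "First paragraph of the talk.\n\nSecond paragraph of the talk."  # placeholder
for formatted_paragraph in process_text_by_paragraphs(
    transcript,
    template_dict={"source_language": "English", "review_count": 5},  # illustrative fields
):
    print(formatted_paragraph)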
process_text_by_sections(transcript, text_object, template_dict, pattern=None, model=None)

High-level function for processing text sections with configurable output handling.

Parameters:

Name Type Description Default
transcript str

Text to process

required
text_object TextObject

Object containing section definitions

required
pattern Optional[Pattern]

Pattern object containing processing instructions

None
template_dict Dict

Dictionary for template substitution

required
model Optional[str]

Optional model identifier for processor

None

Returns:

Type Description
Generator[ProcessedSection, None, None]

Generator yielding ProcessedSection objects

Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def process_text_by_sections(
    transcript: str,
    text_object: TextObject,
    template_dict: Dict,
    pattern: Optional[Pattern] = None,
    model: Optional[str] = None,
) -> Generator[ProcessedSection, None, None]:
    """
    High-level function for processing text sections with configurable output handling.

    Args:
        transcript: Text to process
        text_object: Object containing section definitions
        pattern: Pattern object containing processing instructions
        template_dict: Dictionary for template substitution
        model: Optional model identifier for processor

    Returns:
        Generator yielding ProcessedSection objects
    """
    processor = OpenAIProcessor(model)

    if not pattern:
        pattern = get_default_pattern(DEFAULT_XML_FORMAT_PATTERN)

    section_processor = SectionProcessor(processor, pattern, template_dict)

    return section_processor.process_sections(
        transcript,
        text_object,
    )
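
A sketch that chains find_sections with this wrapper and collects results by section title. The file name and template value are illustrative; the import path is an assumption.

from pathlib import Path

from tnh_scholar.ai_text_processing.ai_text_processing import (  # assumed import path
    find_sections,
    process_text_by_sections,
)

transcript = Path("talk_transcript.txt").read_text()    # hypothetical file
text_object = find_sections(transcript)

results = {
    section.title: section.processed_text
    for section in process_text_by_sections(
        transcript,
        text_object,
        template_dict={"speaker": "Sister Chan Khong"},  # illustrative template value
    )
}
print(list(results))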
punctuate_text(text, source_language=None, punctuate_pattern=None, punctuate_model=None, template_dict=None)
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def punctuate_text(
    text,
    source_language: Optional[str] = None,
    punctuate_pattern: Optional[Pattern] = None,
    punctuate_model: Optional[str] = None,
    template_dict: Optional[Dict] = None,
) -> str:

    if not punctuate_model:
        punctuate_model = DEFAULT_PUNCTUATE_MODEL

    if not punctuate_pattern:
        punctuate_pattern = get_default_pattern(DEFAULT_PUNCTUATE_PATTERN)

    punctuator = TextPunctuator(
        processor=OpenAIProcessor(punctuate_model),
        source_language=source_language,
        punctuate_pattern=punctuate_pattern,
    )

    return punctuator.punctuate_text(
        text, source_language=source_language, template_dict=template_dict
    )
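
The one-call form, letting the helper fall back to the default model and pattern; the input string is a placeholder and the import path is an assumption.

from tnh_scholar.ai_text_processing.ai_text_processing import punctuate_text  # assumed path

raw = "dear friends this morning we practice sitting meditation together"  # placeholder
print(punctuate_text(raw, source_language="English"))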
translate_text_by_lines(text, source_language=None, target_language=None, pattern=None, model=None, style=None, segment_size=None, context_lines=None, review_count=None, template_dict=None)
Source code in src/tnh_scholar/ai_text_processing/ai_text_processing.py
def translate_text_by_lines(
    text,
    source_language: Optional[str] = None,
    target_language: Optional[str] = None,
    pattern: Optional[Pattern] = None,
    model: Optional[str] = None,
    style: Optional[str] = None,
    segment_size: Optional[int] = None,
    context_lines: Optional[int] = None,
    review_count: Optional[int] = None,
    template_dict: Optional[Dict] = None,
) -> str:

    if pattern is None:
        pattern = get_default_pattern(DEFAULT_TRANSLATION_PATTERN)

    translator = LineTranslator(
        processor=OpenAIProcessor(model),
        pattern=pattern,
        style=style or DEFAULT_TRANSLATE_STYLE,
        context_lines=context_lines or DEFAULT_TRANSLATE_CONTEXT_LINES,
        review_count=review_count or DEFAULT_REVIEW_COUNT,
    )

    return translator.translate_text(
        text,
        source_language=source_language,
        target_language=target_language,
        segment_size=segment_size,
        template_dict=template_dict,
    )
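
A sketch of line-oriented translation with the default pattern, style, and review settings; the sample lines are placeholders and the import path is an assumption.

from tnh_scholar.ai_text_processing.ai_text_processing import (  # assumed import path
    translate_text_by_lines,
)

vietnamese = "Thở vào, tôi biết tôi đang thở vào.\nThở ra, tôi mỉm cười."  # placeholder lines
english = translate_text_by_lines(
    vietnamese,
    source_language="Vietnamese",
    target_language="English",
    segment_size=10,
)
print(english)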

openai_process_interface

TOKEN_BUFFER = 500 module-attribute
logger = get_child_logger(__name__) module-attribute
openai_process_text(text_input, process_instructions, model=None, response_format=None, batch=False, max_tokens=0)

Process text with OpenAI using the given processing instructions.

Source code in src/tnh_scholar/ai_text_processing/openai_process_interface.py
def openai_process_text(
    text_input: str,
    process_instructions: str,
    model: Optional[str] = None,
    response_format: Optional[ResponseFormat] = None,
    batch: bool = False,
    max_tokens: int = 0,
) -> Union[ResponseFormat, str]:
    """postprocessing a transcription."""

    user_prompts = [text_input]
    system_message = process_instructions

    logger.debug(f"OpenAI Process Text with process instructions:\n{system_message}")
    if max_tokens == 0:
        tokens = token_count(text_input)
        max_tokens = tokens + TOKEN_BUFFER

    model_name = model or "default"

    logger.info(
        f"Open AI Text Processing{' as batch process' if batch else ''} with model '{model_name}' initiated. Requesting a maximum of {max_tokens} tokens."
    )

    if batch:
        return _run_batch_process_text(
            response_format, user_prompts, system_message, max_tokens
        )
    completion = run_immediate_completion_simple(
        system_message,
        text_input,
        max_tokens=max_tokens,
        response_format=response_format,
    )
    logger.debug(f"Full completion:\n{completion}")
    if response_format:
        process_object = get_completion_object(completion)
        logger.info("Processing completed.")
        return process_object
    else:
        process_text = get_completion_content(completion)
        logger.info("Processing completed.")
        return process_text
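
A sketch of calling the OpenAI interface directly with ad-hoc instructions instead of a Pattern. The instructions and input are placeholders; the import path is an assumption.

from tnh_scholar.ai_text_processing.openai_process_interface import (  # assumed import path
    openai_process_text,
)

result = openai_process_text(
    text_input="breathing in i calm my body",                        # placeholder input
    process_instructions="Add punctuation and capitalization. Return only the text.",
    model="gpt-4o",
)
print(result)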

patterns

MarkdownStr = NewType('MarkdownStr', str) module-attribute
SYSTEM_UPDATE_MESSAGE = 'PatternManager System Update:' module-attribute
logger = get_child_logger(__name__) module-attribute
ConcurrentAccessManager

Manages concurrent access to pattern files.

Provides:
- File-level locking
- Safe concurrent access patterns
- Lock cleanup

Source code in src/tnh_scholar/ai_text_processing/patterns.py
class ConcurrentAccessManager:
    """
    Manages concurrent access to pattern files.

    Provides:
    - File-level locking
    - Safe concurrent access patterns
    - Lock cleanup
    """

    def __init__(self, lock_dir: Path):
        """
        Initialize access manager.

        Args:
            lock_dir: Directory for lock files
        """
        self.lock_dir = Path(lock_dir)
        self._ensure_lock_dir()
        self._cleanup_stale_locks()

    def _ensure_lock_dir(self) -> None:
        """Create lock directory if it doesn't exist."""
        self.lock_dir.mkdir(parents=True, exist_ok=True)

    def _cleanup_stale_locks(self, max_age: timedelta = timedelta(hours=1)) -> None:
        """
        Remove stale lock files.

        Args:
            max_age: Maximum age for lock files before considered stale
        """
        current_time = datetime.now()
        for lock_file in self.lock_dir.glob("*.lock"):
            try:
                mtime = datetime.fromtimestamp(lock_file.stat().st_mtime)
                if current_time - mtime > max_age:
                    lock_file.unlink()
                    logger.warning(f"Removed stale lock file: {lock_file}")
            except FileNotFoundError:
                # Lock was removed by another process
                pass
            except Exception as e:
                logger.error(f"Error cleaning up lock file {lock_file}: {e}")

    @contextmanager
    def file_lock(self, file_path: Path):
        """
        Context manager for safely accessing files.

        Args:
            file_path: Path to file to lock

        Yields:
            None when lock is acquired

        Raises:
            RuntimeError: If file is already locked
            OSError: If lock file operations fail
        """
        file_path = Path(file_path)
        lock_file_path = self.lock_dir / f"{file_path.stem}.lock"
        lock_fd = None

        try:
            # Open or create lock file
            lock_fd = os.open(str(lock_file_path), os.O_WRONLY | os.O_CREAT)

            try:
                # Attempt to acquire lock
                fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)

                # Write process info to lock file
                pid = os.getpid()
                timestamp = datetime.now().isoformat()
                os.write(lock_fd, f"{pid} {timestamp}\n".encode())

                logger.debug(f"Acquired lock for {file_path}")
                yield

            except BlockingIOError as e:
                raise RuntimeError(
                    f"File {file_path} is locked by another process"
                ) from e

        except OSError as e:
            logger.error(f"Lock operation failed for {file_path}: {e}")
            raise

        finally:
            if lock_fd is not None:
                try:
                    # Release lock and close file descriptor
                    fcntl.flock(lock_fd, fcntl.LOCK_UN)
                    os.close(lock_fd)

                    # Remove lock file
                    lock_file_path.unlink(missing_ok=True)
                    logger.debug(f"Released lock for {file_path}")

                except Exception as e:
                    logger.error(f"Error cleaning up lock for {file_path}: {e}")

    def is_locked(self, file_path: Path) -> bool:
        """
        Check if a file is currently locked.

        Args:
            file_path: Path to file to check

        Returns:
            bool: True if file is locked
        """
        lock_file_path = self.lock_dir / f"{file_path.stem}.lock"

        if not lock_file_path.exists():
            return False

        try:
            with open(lock_file_path, "r") as f:
                # Try to acquire and immediately release lock
                fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
                fcntl.flock(f, fcntl.LOCK_UN)
                return False
        except BlockingIOError:
            return True
        except Exception:
            return False
lock_dir = Path(lock_dir) instance-attribute
__init__(lock_dir)

Initialize access manager.

Parameters:

Name Type Description Default
lock_dir Path

Directory for lock files

required
Source code in src/tnh_scholar/ai_text_processing/patterns.py
def __init__(self, lock_dir: Path):
    """
    Initialize access manager.

    Args:
        lock_dir: Directory for lock files
    """
    self.lock_dir = Path(lock_dir)
    self._ensure_lock_dir()
    self._cleanup_stale_locks()
file_lock(file_path)

Context manager for safely accessing files.

Parameters:

Name Type Description Default
file_path Path

Path to file to lock

required

Yields:

Type Description

None when lock is acquired

Raises:

Type Description
RuntimeError

If file is already locked

OSError

If lock file operations fail

Source code in src/tnh_scholar/ai_text_processing/patterns.py
@contextmanager
def file_lock(self, file_path: Path):
    """
    Context manager for safely accessing files.

    Args:
        file_path: Path to file to lock

    Yields:
        None when lock is acquired

    Raises:
        RuntimeError: If file is already locked
        OSError: If lock file operations fail
    """
    file_path = Path(file_path)
    lock_file_path = self.lock_dir / f"{file_path.stem}.lock"
    lock_fd = None

    try:
        # Open or create lock file
        lock_fd = os.open(str(lock_file_path), os.O_WRONLY | os.O_CREAT)

        try:
            # Attempt to acquire lock
            fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)

            # Write process info to lock file
            pid = os.getpid()
            timestamp = datetime.now().isoformat()
            os.write(lock_fd, f"{pid} {timestamp}\n".encode())

            logger.debug(f"Acquired lock for {file_path}")
            yield

        except BlockingIOError as e:
            raise RuntimeError(
                f"File {file_path} is locked by another process"
            ) from e

    except OSError as e:
        logger.error(f"Lock operation failed for {file_path}: {e}")
        raise

    finally:
        if lock_fd is not None:
            try:
                # Release lock and close file descriptor
                fcntl.flock(lock_fd, fcntl.LOCK_UN)
                os.close(lock_fd)

                # Remove lock file
                lock_file_path.unlink(missing_ok=True)
                logger.debug(f"Released lock for {file_path}")

            except Exception as e:
                logger.error(f"Error cleaning up lock for {file_path}: {e}")
is_locked(file_path)

Check if a file is currently locked.

Parameters:

Name Type Description Default
file_path Path

Path to file to check

required

Returns:

Name Type Description
bool bool

True if file is locked

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def is_locked(self, file_path: Path) -> bool:
    """
    Check if a file is currently locked.

    Args:
        file_path: Path to file to check

    Returns:
        bool: True if file is locked
    """
    lock_file_path = self.lock_dir / f"{file_path.stem}.lock"

    if not lock_file_path.exists():
        return False

    try:
        with open(lock_file_path, "r") as f:
            # Try to acquire and immediately release lock
            fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
            fcntl.flock(f, fcntl.LOCK_UN)
            return False
    except BlockingIOError:
        return True
    except Exception:
        return False
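
A usage sketch combining is_locked and file_lock to guard edits to a pattern file. The directory and file paths are illustrative, and the fcntl-based locking shown above is POSIX-only.

from pathlib import Path

from tnh_scholar.ai_text_processing.patterns import ConcurrentAccessManager  # assumed path

lock_dir = Path.home() / ".config" / "tnh-scholar" / "patterns" / ".locks"   # hypothetical
manager = ConcurrentAccessManager(lock_dir)

pattern_file = Path.home() / ".config" / "tnh-scholar" / "patterns" / "default_punctuate.md"

if not manager.is_locked(pattern_file):
    with manager.file_lock(pattern_file):
        # Safe to read or modify the pattern file here; the lock is released and
        # the lock file removed when the block exits.
        content = pattern_file.read_text()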
GitBackedRepository

Manages versioned storage of patterns using Git.

Provides basic Git operations while hiding complexity:
- Automatic versioning of changes
- Basic conflict resolution
- History tracking

Source code in src/tnh_scholar/ai_text_processing/patterns.py
class GitBackedRepository:
    """
    Manages versioned storage of patterns using Git.

    Provides basic Git operations while hiding complexity:
    - Automatic versioning of changes
    - Basic conflict resolution
    - History tracking
    """

    def __init__(self, repo_path: Path):
        """
        Initialize or connect to Git repository.

        Args:
            repo_path: Path to repository directory

        Raises:
            GitCommandError: If Git operations fail
        """
        self.repo_path = repo_path

        try:
            # Try to connect to existing repository
            self.repo = Repo(repo_path)
            logger.debug(f"Connected to existing Git repository at {repo_path}")

        except InvalidGitRepositoryError:
            # Initialize new repository if none exists
            logger.info(f"Initializing new Git repository at {repo_path}")
            self.repo = Repo.init(repo_path)

            # Create initial commit if repo is empty
            if not self.repo.head.is_valid():
                # Create and commit .gitignore
                gitignore = repo_path / ".gitignore"
                gitignore.write_text("*.lock\n.DS_Store\n")
                self.repo.index.add([".gitignore"])
                self.repo.index.commit("Initial repository setup")

    def update_file(self, file_path: Path) -> str:
        """
        Stage and commit changes to a file in the Git repository.

        Args:
            file_path: Absolute or relative path to the file.

        Returns:
            str: Commit hash if changes were made.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is outside the repository.
            GitCommandError: If Git operations fail.
        """
        file_path = file_path.resolve()

        # Ensure the file is within the repository
        try:
            rel_path = file_path.relative_to(self.repo_path)
        except ValueError as e:
            raise ValueError(
                f"File {file_path} is not under the repository root {self.repo_path}"
            ) from e

        if not file_path.exists():
            raise FileNotFoundError(f"File does not exist: {file_path}")

        try:
            return self._commit_file_update(rel_path, file_path)
        except GitCommandError as e:
            logger.error(f"Git operation failed: {e}")
            raise

    def _commit_file_update(self, rel_path, file_path):
        if self._is_file_clean(rel_path):
            # Return the current commit hash if no changes
            return self.repo.head.commit.hexsha

        logger.info(f"Detected changes in {rel_path}, updating version control.")
        self.repo.index.add([str(rel_path)])
        commit = self.repo.index.commit(
            f"{SYSTEM_UPDATE_MESSAGE} {rel_path.stem}",
            author=Actor("PatternManager", ""),
        )
        logger.info(f"Committed changes to {file_path}: {commit.hexsha}")
        return commit.hexsha

    def _get_file_revisions(self, file_path: Path) -> List[Commit]:
        """
        Get ordered list of commits that modified a file, most recent first.

        Args:
            file_path: Path to file relative to repository root

        Returns:
            List of Commit objects affecting this file

        Raises:
            GitCommandError: If Git operations fail
        """
        rel_path = file_path.relative_to(self.repo_path)
        try:
            return list(self.repo.iter_commits(paths=str(rel_path)))
        except GitCommandError as e:
            logger.error(f"Failed to get commits for {rel_path}: {e}")
            return []

    def _get_commit_diff(
        self, commit: Commit, file_path: Path, prev_commit: Optional[Commit] = None
    ) -> Tuple[str, str]:
        """
        Get both stat and detailed diff for a commit.

        Args:
            commit: Commit to diff
            file_path: Path relative to repository root
            prev_commit: Previous commit for diff, defaults to commit's parent

        Returns:
            Tuple of (stat_diff, detailed_diff) where:
                stat_diff: Summary of changes (files changed, insertions/deletions)
                detailed_diff: Colored word-level diff with context

        Raises:
            GitCommandError: If Git operations fail
        """
        prev_hash = prev_commit.hexsha if prev_commit else f"{commit.hexsha}^"
        rel_path = file_path.relative_to(self.repo_path)

        try:
            # Get stats diff
            stat = self.repo.git.diff(prev_hash, commit.hexsha, rel_path, stat=True)

            # Get detailed diff
            diff = self.repo.git.diff(
                prev_hash,
                commit.hexsha,
                rel_path,
                unified=2,
                word_diff="plain",
                color="always",
                ignore_space_change=True,
            )

            return stat, diff
        except GitCommandError as e:
            logger.error(f"Failed to get diff for {commit.hexsha}: {e}")
            return "", ""

    def display_history(self, file_path: Path, max_versions: int = 0) -> None:
        """
        Display history of changes for a file with diffs between versions.

        Shows most recent changes first, limited to max_versions entries.
        For each change shows:
        - Commit info and date
        - Stats summary of changes
        - Detailed color diff with 2 lines of context

        Args:
            file_path: Path to file in repository
            max_versions: Maximum number of versions to show, if zero, shows all revisions.

        Example:
            >>> repo.display_history(Path("patterns/format_dharma_talk.yaml"))
            Commit abc123def (2024-12-28 14:30:22):
            1 file changed, 5 insertions(+), 2 deletions(-)

            diff --git a/patterns/format_dharma_talk.yaml ...
            ...
        """

        try:
            # Get commit history
            commits = self._get_file_revisions(file_path)
            if not commits:
                print(f"No history found for {file_path}")
                return

            if max_versions == 0:
                max_versions = len(commits)  # look at all commits.

            # Display limited history with diffs
            for i, commit in enumerate(commits[:max_versions]):
                # Print commit header
                date_str = commit.committed_datetime.strftime("%Y-%m-%d %H:%M:%S")
                print(f"\nCommit {commit.hexsha[:8]} ({date_str}):")
                print(f"Message: {commit.message.strip()}")

                # Get and display diffs
                prev_commit = commits[i + 1] if i + 1 < len(commits) else None
                stat_diff, detailed_diff = self._get_commit_diff(
                    commit, file_path, prev_commit
                )

                if stat_diff:
                    print("\nChanges:")
                    print(stat_diff)
                if detailed_diff:
                    print("\nDetailed diff:")
                    print(detailed_diff)

                print("\033[0m", end="")
                print("-" * 80)  # Visual separator between commits

        except Exception as e:
            logger.error(f"Failed to display history for {file_path}: {e}")
            print(f"Error displaying history: {e}")
            raise

    def _is_file_clean(self, rel_path: Path) -> bool:
        """
        Check if file has uncommitted changes.

        Args:
            rel_path: Path relative to repository root

        Returns:
            bool: True if file has no changes
        """
        return str(rel_path) not in (
            [item.a_path for item in self.repo.index.diff(None)]
            + self.repo.untracked_files
        )
repo = Repo(repo_path) instance-attribute
repo_path = repo_path instance-attribute
__init__(repo_path)

Initialize or connect to Git repository.

Parameters:

Name Type Description Default
repo_path Path

Path to repository directory

required

Raises:

Type Description
GitCommandError

If Git operations fail

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def __init__(self, repo_path: Path):
    """
    Initialize or connect to Git repository.

    Args:
        repo_path: Path to repository directory

    Raises:
        GitCommandError: If Git operations fail
    """
    self.repo_path = repo_path

    try:
        # Try to connect to existing repository
        self.repo = Repo(repo_path)
        logger.debug(f"Connected to existing Git repository at {repo_path}")

    except InvalidGitRepositoryError:
        # Initialize new repository if none exists
        logger.info(f"Initializing new Git repository at {repo_path}")
        self.repo = Repo.init(repo_path)

        # Create initial commit if repo is empty
        if not self.repo.head.is_valid():
            # Create and commit .gitignore
            gitignore = repo_path / ".gitignore"
            gitignore.write_text("*.lock\n.DS_Store\n")
            self.repo.index.add([".gitignore"])
            self.repo.index.commit("Initial repository setup")
display_history(file_path, max_versions=0)

Display history of changes for a file with diffs between versions.

Shows most recent changes first, limited to max_versions entries. For each change shows:
- Commit info and date
- Stats summary of changes
- Detailed color diff with 2 lines of context

Parameters:

Name Type Description Default
file_path Path

Path to file in repository

required
max_versions int

Maximum number of versions to show, if zero, shows all revisions.

0
Example

>>> repo.display_history(Path("patterns/format_dharma_talk.yaml"))
Commit abc123def (2024-12-28 14:30:22):
1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/patterns/format_dharma_talk.yaml ...
...

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def display_history(self, file_path: Path, max_versions: int = 0) -> None:
    """
    Display history of changes for a file with diffs between versions.

    Shows most recent changes first, limited to max_versions entries.
    For each change shows:
    - Commit info and date
    - Stats summary of changes
    - Detailed color diff with 2 lines of context

    Args:
        file_path: Path to file in repository
        max_versions: Maximum number of versions to show, if zero, shows all revisions.

    Example:
        >>> repo.display_history(Path("patterns/format_dharma_talk.yaml"))
        Commit abc123def (2024-12-28 14:30:22):
        1 file changed, 5 insertions(+), 2 deletions(-)

        diff --git a/patterns/format_dharma_talk.yaml ...
        ...
    """

    try:
        # Get commit history
        commits = self._get_file_revisions(file_path)
        if not commits:
            print(f"No history found for {file_path}")
            return

        if max_versions == 0:
            max_versions = len(commits)  # look at all commits.

        # Display limited history with diffs
        for i, commit in enumerate(commits[:max_versions]):
            # Print commit header
            date_str = commit.committed_datetime.strftime("%Y-%m-%d %H:%M:%S")
            print(f"\nCommit {commit.hexsha[:8]} ({date_str}):")
            print(f"Message: {commit.message.strip()}")

            # Get and display diffs
            prev_commit = commits[i + 1] if i + 1 < len(commits) else None
            stat_diff, detailed_diff = self._get_commit_diff(
                commit, file_path, prev_commit
            )

            if stat_diff:
                print("\nChanges:")
                print(stat_diff)
            if detailed_diff:
                print("\nDetailed diff:")
                print(detailed_diff)

            print("\033[0m", end="")
            print("-" * 80)  # Visual separator between commits

    except Exception as e:
        logger.error(f"Failed to display history for {file_path}: {e}")
        print(f"Error displaying history: {e}")
        raise
update_file(file_path)

Stage and commit changes to a file in the Git repository.

Parameters:

Name Type Description Default
file_path Path

Absolute or relative path to the file.

required

Returns:

Name Type Description
str str

Commit hash if changes were made.

Raises:

Type Description
FileNotFoundError

If the file does not exist.

ValueError

If the file is outside the repository.

GitCommandError

If Git operations fail.

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def update_file(self, file_path: Path) -> str:
    """
    Stage and commit changes to a file in the Git repository.

    Args:
        file_path: Absolute or relative path to the file.

    Returns:
        str: Commit hash if changes were made.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file is outside the repository.
        GitCommandError: If Git operations fail.
    """
    file_path = file_path.resolve()

    # Ensure the file is within the repository
    try:
        rel_path = file_path.relative_to(self.repo_path)
    except ValueError as e:
        raise ValueError(
            f"File {file_path} is not under the repository root {self.repo_path}"
        ) from e

    if not file_path.exists():
        raise FileNotFoundError(f"File does not exist: {file_path}")

    try:
        return self._commit_file_update(rel_path, file_path)
    except GitCommandError as e:
        logger.error(f"Git operation failed: {e}")
        raise
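
A sketch of versioning a pattern directory and reviewing one file's history. The repository path and file name are illustrative; the import path is an assumption.

from pathlib import Path

from tnh_scholar.ai_text_processing.patterns import GitBackedRepository  # assumed path

patterns_dir = Path.home() / ".config" / "tnh-scholar" / "patterns"  # hypothetical location
repo = GitBackedRepository(patterns_dir)

pattern_file = patterns_dir / "default_punctuate.md"
commit_hash = repo.update_file(pattern_file)   # commits only if the file has changed
print(f"Current revision: {commit_hash}")

repo.display_history(pattern_file, max_versions=3)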
Pattern

Base Pattern class for version-controlled template patterns.

Patterns contain:
- Instructions: The main pattern instructions as a Jinja2 template. Note: Instructions are intended to be saved in markdown format in a .md file.
- Template fields: Default values for template variables
- Metadata: Name and identifier information

Version control is handled externally through Git, not in the pattern itself. Pattern identity is determined by the combination of identifiers.

Attributes:

Name Type Description
name str

The name of the pattern

instructions str

The Jinja2 template string for this pattern

default_template_fields Dict[str, str]

Default values for template variables

_allow_empty_vars bool

Whether to allow undefined template variables

_env Environment

Configured Jinja2 environment instance

Source code in src/tnh_scholar/ai_text_processing/patterns.py
class Pattern:
    """
    Base Pattern class for version-controlled template patterns.

    Patterns contain:
    - Instructions: The main pattern instructions as a Jinja2 template.
       Note: Instructions are intended to be saved in markdown format in a .md file.
    - Template fields: Default values for template variables
    - Metadata: Name and identifier information

    Version control is handled externally through Git, not in the pattern itself.
    Pattern identity is determined by the combination of identifiers.

    Attributes:
        name (str): The name of the pattern
        instructions (str): The Jinja2 template string for this pattern
        default_template_fields (Dict[str, str]): Default values for template variables
        _allow_empty_vars (bool): Whether to allow undefined template variables
        _env (Environment): Configured Jinja2 environment instance
    """

    def __init__(
        self,
        name: str,
        instructions: MarkdownStr,
        default_template_fields: Optional[Dict[str, str]] = None,
        allow_empty_vars: bool = False,
    ) -> None:
        """
        Initialize a new Pattern instance.

        Args:
            name: Unique name identifying the pattern
            instructions: Jinja2 template string containing the pattern
            default_template_fields: Optional default values for template variables
            allow_empty_vars: Whether to allow undefined template variables

        Raises:
            ValueError: If name or instructions are empty
            TemplateError: If template syntax is invalid
        """
        if not name or not instructions:
            raise ValueError("Name and instructions must not be empty")

        self.name = name
        self.instructions = instructions
        self.default_template_fields = default_template_fields or {}
        self._allow_empty_vars = allow_empty_vars
        self._env = self._create_environment()

        # Validate template syntax on initialization
        self._validate_template()

    @staticmethod
    def _create_environment() -> Environment:
        """
        Create and configure a Jinja2 environment with optimal settings.

        Returns:
            Environment: Configured Jinja2 environment with security and formatting options
        """
        return Environment(
            undefined=StrictUndefined,  # Raise errors for undefined variables
            trim_blocks=True,  # Remove first newline after a block
            lstrip_blocks=True,  # Strip tabs and spaces from the start of lines
            autoescape=True,  # Enable autoescaping for security
        )

    def _validate_template(self) -> None:
        """
        Validate the template syntax without rendering.

        Raises:
            TemplateError: If template syntax is invalid
        """
        try:
            self._env.parse(self.instructions)
        except TemplateError as e:
            raise TemplateError(
                f"Invalid template syntax in pattern '{self.name}': {str(e)}"
            ) from e

    @lru_cache(maxsize=128)
    def _get_template(self) -> Template:
        """
        Get or create a cached template instance.

        Returns:
            Template: Compiled Jinja2 template
        """
        return self._env.from_string(self.instructions)

    def apply_template(self, field_values: Optional[Dict[str, str]] = None) -> str:
        """
        Apply template values to pattern instructions using Jinja2.

        Args:
            field_values: Values to substitute into the template.
                        If None, default_template_fields are used.

        Returns:
            str: Rendered instructions with template values applied.

        Raises:
            TemplateError: If template rendering fails
            ValueError: If required template variables are missing
        """
        # Combine default fields with provided fields, with provided taking precedence
        template_values = {**self.default_template_fields, **(field_values or {})}

        instructions = self.get_content_without_frontmatter()

        try:
            return self._render_template_with_values(instructions, template_values)
        except TemplateError as e:
            raise TemplateError(
                f"Template rendering failed for pattern '{self.name}': {str(e)}"
            ) from e

    def _render_template_with_values(self, instructions, template_values):
        # Parse template to find required variables

        parsed_content = self._env.parse(instructions)
        required_vars = find_undeclared_variables(parsed_content)

        # Validate all required variables are provided
        missing_vars = required_vars - set(template_values.keys())
        if missing_vars and not self._allow_empty_vars:
            raise ValueError(
                f"Missing required template variables in pattern '{self.name}': "
                f"{', '.join(sorted(missing_vars))}"
            )

        template = self._get_template()
        return template.render(**template_values)

    def extract_frontmatter(self) -> Optional[Dict[str, Any]]:
        """
        Extract and validate YAML frontmatter from markdown instructions.

        Returns:
            Optional[Dict]: Frontmatter data if found and valid, None otherwise

        Note:
            Frontmatter must be at the very start of the file and properly formatted.
        """

        # More precise pattern matching
        pattern = r"\A---\s*\n(.*?)\n---\s*\n"
        if match := re.match(pattern, self.instructions, re.DOTALL):
            try:
                frontmatter = yaml.safe_load(match[1])
                if not isinstance(frontmatter, dict):
                    logger.warning("Frontmatter must be a YAML dictionary")
                    return None
                return frontmatter
            except yaml.YAMLError as e:
                logger.warning(f"Invalid YAML in frontmatter: {e}")
                return None
        return None

    def get_content_without_frontmatter(self) -> str:
        """
        Get markdown content with frontmatter removed.

        Returns:
            str: Markdown content without frontmatter
        """
        pattern = r"\A---\s*\n.*?\n---\s*\n"
        return re.sub(pattern, "", self.instructions, flags=re.DOTALL)

    def update_frontmatter(self, new_data: Dict[str, Any]) -> None:
        """
        Update or add frontmatter to the markdown content.

        Args:
            new_data: Dictionary of frontmatter fields to update
        """

        current_frontmatter = self.extract_frontmatter() or {}
        updated_frontmatter = {**current_frontmatter, **new_data}

        # Create YAML string
        yaml_str = yaml.dump(
            updated_frontmatter, default_flow_style=False, allow_unicode=True
        )

        # Remove existing frontmatter if present
        content = self.get_content_without_frontmatter()

        # Combine new frontmatter with content
        self.instructions = f"---\n{yaml_str}---\n\n{content}"

    def content_hash(self) -> str:
        """
        Generate a SHA-256 hash of the pattern content.

        Useful for quick content comparison and change detection.

        Returns:
            str: Hexadecimal string of the SHA-256 hash
        """
        content = f"{self.name}{self.instructions}{sorted(self.default_template_fields.items())}"
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert pattern to dictionary for serialization.

        Returns:
            Dict containing all pattern data in serializable format
        """
        return {
            "name": self.name,
            "instructions": self.instructions,
            "default_template_fields": self.default_template_fields,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Pattern":
        """
        Create pattern instance from dictionary data.

        Args:
            data: Dictionary containing pattern data

        Returns:
            Pattern: New pattern instance

        Raises:
            ValueError: If required fields are missing
        """
        required_fields = {"name", "instructions"}
        if missing_fields := required_fields - set(data.keys()):
            raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

        return cls(
            name=data["name"],
            instructions=data["instructions"],
            default_template_fields=data.get("default_template_fields", {}),
        )

    def __eq__(self, other: object) -> bool:
        """Compare patterns based on their content."""
        if not isinstance(other, Pattern):
            return NotImplemented
        return self.content_hash() == other.content_hash()

    def __hash__(self) -> int:
        """Hash based on content hash for container operations."""
        return hash(self.content_hash())
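
A small, self-contained sketch of constructing a Pattern inline and rendering it. The instruction text and field names are invented for illustration; the import path is an assumption.

from tnh_scholar.ai_text_processing.patterns import Pattern  # assumed import path

pattern = Pattern(
    name="greeting_example",
    instructions=(
        "Translate the following text into {{ target_language }} "
        "using a {{ style }} register."
    ),
    default_template_fields={"style": "formal"},
)

# Provided fields are merged over the defaults before rendering.
print(pattern.apply_template({"target_language": "English"}))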
default_template_fields = default_template_fields or {} instance-attribute
instructions = instructions instance-attribute
name = name instance-attribute
__eq__(other)

Compare patterns based on their content.

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def __eq__(self, other: object) -> bool:
    """Compare patterns based on their content."""
    if not isinstance(other, Pattern):
        return NotImplemented
    return self.content_hash() == other.content_hash()
__hash__()

Hash based on content hash for container operations.

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def __hash__(self) -> int:
    """Hash based on content hash for container operations."""
    return hash(self.content_hash())
__init__(name, instructions, default_template_fields=None, allow_empty_vars=False)

Initialize a new Pattern instance.

Parameters:

Name Type Description Default
name str

Unique name identifying the pattern

required
instructions MarkdownStr

Jinja2 template string containing the pattern

required
default_template_fields Optional[Dict[str, str]]

Optional default values for template variables

None
allow_empty_vars bool

Whether to allow undefined template variables

False

Raises:

Type Description
ValueError

If name or instructions are empty

TemplateError

If template syntax is invalid

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def __init__(
    self,
    name: str,
    instructions: MarkdownStr,
    default_template_fields: Optional[Dict[str, str]] = None,
    allow_empty_vars: bool = False,
) -> None:
    """
    Initialize a new Pattern instance.

    Args:
        name: Unique name identifying the pattern
        instructions: Jinja2 template string containing the pattern
        default_template_fields: Optional default values for template variables
        allow_empty_vars: Whether to allow undefined template variables

    Raises:
        ValueError: If name or instructions are empty
        TemplateError: If template syntax is invalid
    """
    if not name or not instructions:
        raise ValueError("Name and instructions must not be empty")

    self.name = name
    self.instructions = instructions
    self.default_template_fields = default_template_fields or {}
    self._allow_empty_vars = allow_empty_vars
    self._env = self._create_environment()

    # Validate template syntax on initialization
    self._validate_template()
apply_template(field_values=None)

Apply template values to pattern instructions using Jinja2.

Parameters:

Name Type Description Default
field_values Optional[Dict[str, str]]

Values to substitute into the template. If None, default_template_fields are used.

None

Returns:

Name Type Description
str str

Rendered instructions with template values applied.

Raises:

Type Description
TemplateError

If template rendering fails

ValueError

If required template variables are missing

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def apply_template(self, field_values: Optional[Dict[str, str]] = None) -> str:
    """
    Apply template values to pattern instructions using Jinja2.

    Args:
        field_values: Values to substitute into the template.
                    If None, default_template_fields are used.

    Returns:
        str: Rendered instructions with template values applied.

    Raises:
        TemplateError: If template rendering fails
        ValueError: If required template variables are missing
    """
    # Combine default fields with provided fields, with provided taking precedence
    template_values = {**self.default_template_fields, **(field_values or {})}

    instructions = self.get_content_without_frontmatter()

    try:
        return self._render_template_with_values(instructions, template_values)
    except TemplateError as e:
        raise TemplateError(
            f"Template rendering failed for pattern '{self.name}': {str(e)}"
        ) from e
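
For illustration, here is a minimal sketch of template application using a made-up pattern with one Jinja2 placeholder; the import path is assumed to mirror the source layout shown above:

from tnh_scholar.ai_text_processing.patterns import Pattern  # assumed import path

# Hypothetical pattern: one placeholder with a default value
pattern = Pattern(
    name="greeting",
    instructions="Respond in {{ language }}.",
    default_template_fields={"language": "English"},
)

print(pattern.apply_template())                            # default fields -> "Respond in English."
print(pattern.apply_template({"language": "Vietnamese"}))  # explicit fields take precedence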
content_hash()

Generate a SHA-256 hash of the pattern content.

Useful for quick content comparison and change detection.

Returns:

Name Type Description
str str

Hexadecimal string of the SHA-256 hash

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def content_hash(self) -> str:
    """
    Generate a SHA-256 hash of the pattern content.

    Useful for quick content comparison and change detection.

    Returns:
        str: Hexadecimal string of the SHA-256 hash
    """
    content = f"{self.name}{self.instructions}{sorted(self.default_template_fields.items())}"
    return hashlib.sha256(content.encode("utf-8")).hexdigest()
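
Because equality and hashing are both derived from content_hash(), two independently constructed patterns with identical content are interchangeable. A small sketch (import path assumed as above):

from tnh_scholar.ai_text_processing.patterns import Pattern  # assumed import path

p1 = Pattern(name="greeting", instructions="Respond in {{ language }}.")
p2 = Pattern(name="greeting", instructions="Respond in {{ language }}.")

assert p1 == p2                # identical name/instructions/defaults -> equal
assert len({p1, p2}) == 1      # content-based __hash__ makes Pattern usable in sets
print(p1.content_hash()[:12])  # prefix of the SHA-256 hex digest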
extract_frontmatter()

Extract and validate YAML frontmatter from markdown instructions.

Returns:

Type Description
Optional[Dict[str, Any]]

Optional[Dict]: Frontmatter data if found and valid, None otherwise

Note

Frontmatter must be at the very start of the file and properly formatted.

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def extract_frontmatter(self) -> Optional[Dict[str, Any]]:
    """
    Extract and validate YAML frontmatter from markdown instructions.

    Returns:
        Optional[Dict]: Frontmatter data if found and valid, None otherwise

    Note:
        Frontmatter must be at the very start of the file and properly formatted.
    """

    # More precise pattern matching
    pattern = r"\A---\s*\n(.*?)\n---\s*\n"
    if match := re.match(pattern, self.instructions, re.DOTALL):
        try:
            frontmatter = yaml.safe_load(match[1])
            if not isinstance(frontmatter, dict):
                logger.warning("Frontmatter must be a YAML dictionary")
                return None
            return frontmatter
        except yaml.YAMLError as e:
            logger.warning(f"Invalid YAML in frontmatter: {e}")
            return None
    return None
from_dict(data) classmethod

Create pattern instance from dictionary data.

Parameters:

Name Type Description Default
data Dict[str, Any]

Dictionary containing pattern data

required

Returns:

Name Type Description
Pattern Pattern

New pattern instance

Raises:

Type Description
ValueError

If required fields are missing

Source code in src/tnh_scholar/ai_text_processing/patterns.py
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Pattern":
    """
    Create pattern instance from dictionary data.

    Args:
        data: Dictionary containing pattern data

    Returns:
        Pattern: New pattern instance

    Raises:
        ValueError: If required fields are missing
    """
    required_fields = {"name", "instructions"}
    if missing_fields := required_fields - set(data.keys()):
        raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

    return cls(
        name=data["name"],
        instructions=data["instructions"],
        default_template_fields=data.get("default_template_fields", {}),
    )
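
For serialization round trips, to_dict() and from_dict() are inverses for the fields they carry. A minimal sketch (import path assumed as above):

from tnh_scholar.ai_text_processing.patterns import Pattern  # assumed import path

original = Pattern(
    name="greeting",
    instructions="Respond in {{ language }}.",
    default_template_fields={"language": "English"},
)

data = original.to_dict()           # plain dict, suitable for JSON/YAML storage
restored = Pattern.from_dict(data)
assert restored == original         # content-based equality survives the round trip

# Pattern.from_dict({"name": "incomplete"}) would raise ValueError (missing 'instructions')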
get_content_without_frontmatter()

Get markdown content with frontmatter removed.

Returns:

Name Type Description
str str

Markdown content without frontmatter

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def get_content_without_frontmatter(self) -> str:
    """
    Get markdown content with frontmatter removed.

    Returns:
        str: Markdown content without frontmatter
    """
    pattern = r"\A---\s*\n.*?\n---\s*\n"
    return re.sub(pattern, "", self.instructions, flags=re.DOTALL)
to_dict()

Convert pattern to dictionary for serialization.

Returns:

Type Description
Dict[str, Any]

Dict containing all pattern data in serializable format

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def to_dict(self) -> Dict[str, Any]:
    """
    Convert pattern to dictionary for serialization.

    Returns:
        Dict containing all pattern data in serializable format
    """
    return {
        "name": self.name,
        "instructions": self.instructions,
        "default_template_fields": self.default_template_fields,
    }
update_frontmatter(new_data)

Update or add frontmatter to the markdown content.

Parameters:

Name Type Description Default
new_data Dict[str, Any]

Dictionary of frontmatter fields to update

required
Source code in src/tnh_scholar/ai_text_processing/patterns.py
def update_frontmatter(self, new_data: Dict[str, Any]) -> None:
    """
    Update or add frontmatter to the markdown content.

    Args:
        new_data: Dictionary of frontmatter fields to update
    """

    current_frontmatter = self.extract_frontmatter() or {}
    updated_frontmatter = {**current_frontmatter, **new_data}

    # Create YAML string
    yaml_str = yaml.dump(
        updated_frontmatter, default_flow_style=False, allow_unicode=True
    )

    # Remove existing frontmatter if present
    content = self.get_content_without_frontmatter()

    # Combine new frontmatter with content
    self.instructions = f"---\n{yaml_str}---\n\n{content}"
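
The frontmatter helpers work together as in this minimal sketch: update_frontmatter() rewrites the instructions with a YAML header, extract_frontmatter() reads it back, and get_content_without_frontmatter() returns only the template body (import path assumed as above; metadata values are hypothetical):

from tnh_scholar.ai_text_processing.patterns import Pattern  # assumed import path

pattern = Pattern(name="greeting", instructions="Respond in {{ language }}.")

pattern.update_frontmatter({"version": 1, "author": "editor"})    # hypothetical metadata
print(pattern.extract_frontmatter())               # e.g. {'author': 'editor', 'version': 1}
print(pattern.get_content_without_frontmatter())   # "Respond in {{ language }}."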
PatternManager

Main interface for pattern management system.

Provides high-level operations:
- Pattern creation and loading
- Automatic versioning
- Safe concurrent access
- Basic history tracking

Source code in src/tnh_scholar/ai_text_processing/patterns.py
class PatternManager:
    """
    Main interface for pattern management system.

    Provides high-level operations:
    - Pattern creation and loading
    - Automatic versioning
    - Safe concurrent access
    - Basic history tracking
    """

    def __init__(self, base_path: Path):
        """
        Initialize pattern management system.

        Args:
            base_path: Base directory for pattern storage
        """
        self.base_path = Path(base_path).resolve()
        self.base_path.mkdir(parents=True, exist_ok=True)

        # Initialize subsystems
        self.repo = GitBackedRepository(self.base_path)
        self.access_manager = ConcurrentAccessManager(self.base_path / ".locks")

        logger.info(f"Initialized pattern management system at {base_path}")

    def _normalize_path(self, path: Union[str, Path]) -> Path:
        """
        Normalize a path to be absolute under the repository base path.

        Handles these cases, resolving both to the same result:
        - "my_file" -> <base_path>/my_file
        - "<base_path>/my_file" -> <base_path>/my_file

        Args:
            path: Input path as string or Path

        Returns:
            Path: Absolute path under base_path

        Raises:
            ValueError: If path would resolve outside repository
        """
        path = Path(path)  # ensure we have a path

        # Join with base_path as needed
        if not path.is_absolute():
            path = path if path.parent == self.base_path else self.base_path / path

        # Safety check after resolution
        resolved = path.resolve()
        if not resolved.is_relative_to(self.base_path):
            raise ValueError(
                f"Path {path} resolves outside repository: {self.base_path}"
            )

        return resolved

    def get_pattern_path(self, pattern_name: str) -> Optional[Path]:
        """
        Recursively search for a pattern file with the given name in base_path and all subdirectories.

        Args:
            pattern_name: Pattern name to search for (without the .md extension)

        Returns:
            Optional[Path]: Full path to the found pattern file, or None if not found
        """
        pattern = f"{pattern_name}.md"

        try:
            pattern_path = next(
                path for path in self.base_path.rglob(pattern) if path.is_file()
            )
            logger.debug(f"Found pattern file for ID {pattern_name} at: {pattern_path}")
            return self._normalize_path(pattern_path)

        except StopIteration:
            logger.debug(f"No pattern file found with ID: {pattern_name}")
            return None

    def save_pattern(self, pattern: Pattern, subdir: Optional[Path] = None) -> Path:
        """
        Save a pattern as a Markdown file and track it in the Git-backed repository.

        Args:
            pattern: Pattern instance to save.
            subdir: Optional subdirectory (relative to base_path) for the file.

        Returns:
            Path: Location of the saved file, relative to base_path.

        Raises:
            ValueError: If a pattern with the same name already exists at a
                different location in the repository.
        """
        pattern_name = pattern.name
        instructions = pattern.instructions

        if subdir is None:
            path = self.base_path / f"{pattern_name}.md"
        else:
            path = self.base_path / subdir / f"{pattern_name}.md"

        path = self._normalize_path(path)

        # Check for existing pattern with same ID
        existing_path = self.get_pattern_path(pattern_name)

        if existing_path is not None and path != existing_path:
            error_msg = (
                f"Existing pattern - {pattern_name} already exists at "
                f"{existing_path.relative_to(self.base_path)}. "
                f"Attempted to access at location: {path.relative_to(self.base_path)}"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        try:
            with self.access_manager.file_lock(path):
                write_text_to_file(path, instructions, overwrite=True)
                self.repo.update_file(path)
                logger.info(f"Pattern saved at {path}")
                return path.relative_to(self.base_path)

        except Exception as e:
            logger.error(f"Failed to save pattern {pattern.name}: {e}")
            raise

    def load_pattern(self, pattern_name: str) -> Pattern:
        """
        Load the .md pattern file by name, extract placeholders, and
        return a fully constructed Pattern object.

        Args:
            pattern_name: Name of the pattern (without .md extension).

        Returns:
            A new Pattern object whose 'instructions' is the file's text
            and whose 'template_fields' are inferred from placeholders in
            those instructions.
        """
        # Locate the .md file; raise if missing
        path = self.get_pattern_path(pattern_name)
        if not path:
            raise FileNotFoundError(f"No pattern file named {pattern_name}.md found.")

        # Acquire lock before reading
        with self.access_manager.file_lock(path):
            instructions = get_text_from_file(path)

        instructions = MarkdownStr(instructions)

        # Create the pattern from the raw .md text
        pattern = Pattern(name=pattern_name, instructions=instructions)

        # Check for local uncommitted changes, updating file:
        self.repo.update_file(path)

        return pattern

    def show_pattern_history(self, pattern_name: str) -> None:
        if path := self.get_pattern_path(pattern_name):
            self.repo.display_history(path)
        else:
            logger.error(f"Path to {pattern_name} not found.")
            return

    # def get_pattern_history_from_path(self, path: Path) -> List[Dict[str, Any]]:
    #     """
    #     Get version history for a pattern.

    #     Args:
    #         path: Path to pattern file

    #     Returns:
    #         List of version information
    #     """
    #     path = self._normalize_path(path)

    #     return self.repo.get_history(path)

    @classmethod
    def verify_repository(cls, base_path: Path) -> bool:
        """
        Verify repository integrity and uniqueness of pattern names.

        Performs the following checks:
        1. Validates Git repository structure.
        2. Ensures no duplicate pattern names exist.

        Args:
            base_path: Repository path to verify.

        Returns:
            bool: True if the repository is valid and contains no duplicate pattern files.
        """
        try:
            # Check if it's a valid Git repository
            repo = Repo(base_path)

            # Verify basic repository structure
            basic_valid = (
                repo.head.is_valid()
                and not repo.bare
                and (base_path / ".git").is_dir()
                and (base_path / ".locks").is_dir()
            )

            if not basic_valid:
                return False

            # Check for duplicate pattern names
            pattern_files = list(base_path.rglob("*.md"))
            seen_names = {}

            for pattern_file in pattern_files:
                # Skip files in .git directory
                if ".git" in pattern_file.parts:
                    continue

                # Get pattern name from the filename (without extension)
                pattern_name = pattern_file.stem

                if pattern_name in seen_names:
                    logger.error(
                        f"Duplicate pattern file detected:\n"
                        f"  First occurrence: {seen_names[pattern_name]}\n"
                        f"  Second occurrence: {pattern_file}"
                    )
                    return False

                seen_names[pattern_name] = pattern_file

            return True

        except (InvalidGitRepositoryError, Exception) as e:
            logger.error(f"Repository verification failed: {e}")
            return False
access_manager = ConcurrentAccessManager(self.base_path / '.locks') instance-attribute
base_path = Path(base_path).resolve() instance-attribute
repo = GitBackedRepository(self.base_path) instance-attribute
__init__(base_path)

Initialize pattern management system.

Parameters:

Name Type Description Default
base_path Path

Base directory for pattern storage

required
Source code in src/tnh_scholar/ai_text_processing/patterns.py
def __init__(self, base_path: Path):
    """
    Initialize pattern management system.

    Args:
        base_path: Base directory for pattern storage
    """
    self.base_path = Path(base_path).resolve()
    self.base_path.mkdir(parents=True, exist_ok=True)

    # Initialize subsystems
    self.repo = GitBackedRepository(self.base_path)
    self.access_manager = ConcurrentAccessManager(self.base_path / ".locks")

    logger.info(f"Initialized pattern management system at {base_path}")
get_pattern_path(pattern_name)

Recursively search for a pattern file with the given name in base_path and all subdirectories.

Parameters:

Name Type Description Default
pattern_name str

Pattern name to search for (without the .md extension)

required

Returns:

Type Description
Optional[Path]

Optional[Path]: Full path to the found pattern file, or None if not found

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def get_pattern_path(self, pattern_name: str) -> Optional[Path]:
    """
    Recursively search for a pattern file with the given name in base_path and all subdirectories.

    Args:
        pattern_name: Pattern name to search for (without the .md extension)

    Returns:
        Optional[Path]: Full path to the found pattern file, or None if not found
    """
    pattern = f"{pattern_name}.md"

    try:
        pattern_path = next(
            path for path in self.base_path.rglob(pattern) if path.is_file()
        )
        logger.debug(f"Found pattern file for ID {pattern_name} at: {pattern_path}")
        return self._normalize_path(pattern_path)

    except StopIteration:
        logger.debug(f"No pattern file found with ID: {pattern_name}")
        return None
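
A quick lookup sketch, assuming a pattern repository directory and a hypothetical pattern name (import path assumed from the source layout):

from pathlib import Path
from tnh_scholar.ai_text_processing.patterns import PatternManager  # assumed import path

pm = PatternManager(Path("patterns"))             # hypothetical repository location
path = pm.get_pattern_path("format_dharma_talk")  # searches base_path recursively for format_dharma_talk.md
print(path or "pattern not found")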
load_pattern(pattern_name)

Load the .md pattern file by name, extract placeholders, and return a fully constructed Pattern object.

Parameters:

Name Type Description Default
pattern_name str

Name of the pattern (without .md extension).

required

Returns:

Type Description
Pattern

A new Pattern object whose 'instructions' is the file's text and whose 'template_fields' are inferred from placeholders in those instructions.

Source code in src/tnh_scholar/ai_text_processing/patterns.py
def load_pattern(self, pattern_name: str) -> Pattern:
    """
    Load the .md pattern file by name, extract placeholders, and
    return a fully constructed Pattern object.

    Args:
        pattern_name: Name of the pattern (without .md extension).

    Returns:
        A new Pattern object whose 'instructions' is the file's text
        and whose 'template_fields' are inferred from placeholders in
        those instructions.
    """
    # Locate the .md file; raise if missing
    path = self.get_pattern_path(pattern_name)
    if not path:
        raise FileNotFoundError(f"No pattern file named {pattern_name}.md found.")

    # Acquire lock before reading
    with self.access_manager.file_lock(path):
        instructions = get_text_from_file(path)

    instructions = MarkdownStr(instructions)

    # Create the pattern from the raw .md text
    pattern = Pattern(name=pattern_name, instructions=instructions)

    # Check for local uncommitted changes, updating file:
    self.repo.update_file(path)

    return pattern
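
Loading and applying a pattern typically looks like the following sketch (hypothetical repository path, pattern name, and template field; import path assumed from the source layout):

from pathlib import Path
from tnh_scholar.ai_text_processing.patterns import PatternManager  # assumed import path

pm = PatternManager(Path("patterns"))            # hypothetical repository location
pattern = pm.load_pattern("format_dharma_talk")  # reads format_dharma_talk.md under a file lock
instructions = pattern.apply_template({"language": "English"})  # assumes the pattern uses {{ language }}
print(instructions[:200])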
save_pattern(pattern, subdir=None)
Source code in src/tnh_scholar/ai_text_processing/patterns.py
def save_pattern(self, pattern: Pattern, subdir: Optional[Path] = None) -> Path:
    """
    Save a pattern as a Markdown file and track it in the Git-backed repository.

    Args:
        pattern: Pattern instance to save.
        subdir: Optional subdirectory (relative to base_path) for the file.

    Returns:
        Path: Location of the saved file, relative to base_path.

    Raises:
        ValueError: If a pattern with the same name already exists at a
            different location in the repository.
    """
    pattern_name = pattern.name
    instructions = pattern.instructions

    if subdir is None:
        path = self.base_path / f"{pattern_name}.md"
    else:
        path = self.base_path / subdir / f"{pattern_name}.md"

    path = self._normalize_path(path)

    # Check for existing pattern with same ID
    existing_path = self.get_pattern_path(pattern_name)

    if existing_path is not None and path != existing_path:
        error_msg = (
            f"Existing pattern - {pattern_name} already exists at "
            f"{existing_path.relative_to(self.base_path)}. "
            f"Attempted to access at location: {path.relative_to(self.base_path)}"
        )
        logger.error(error_msg)
        raise ValueError(error_msg)

    try:
        with self.access_manager.file_lock(path):
            write_text_to_file(path, instructions, overwrite=True)
            self.repo.update_file(path)
            logger.info(f"Pattern saved at {path}")
            return path.relative_to(self.base_path)

    except Exception as e:
        logger.error(f"Failed to save pattern {pattern.name}: {e}")
        raise
show_pattern_history(pattern_name)
Source code in src/tnh_scholar/ai_text_processing/patterns.py
def show_pattern_history(self, pattern_name: str) -> None:
    if path := self.get_pattern_path(pattern_name):
        self.repo.display_history(path)
    else:
        logger.error(f"Path to {pattern_name} not found.")
        return
verify_repository(base_path) classmethod

Verify repository integrity and uniqueness of pattern names.

Performs the following checks:
1. Validates Git repository structure.
2. Ensures no duplicate pattern names exist.

Parameters:

Name Type Description Default
base_path Path

Repository path to verify.

required

Returns:

Name Type Description
bool bool

True if the repository is valid and contains no duplicate pattern files.

Source code in src/tnh_scholar/ai_text_processing/patterns.py
@classmethod
def verify_repository(cls, base_path: Path) -> bool:
    """
    Verify repository integrity and uniqueness of pattern names.

    Performs the following checks:
    1. Validates Git repository structure.
    2. Ensures no duplicate pattern names exist.

    Args:
        base_path: Repository path to verify.

    Returns:
        bool: True if the repository is valid and contains no duplicate pattern files.
    """
    try:
        # Check if it's a valid Git repository
        repo = Repo(base_path)

        # Verify basic repository structure
        basic_valid = (
            repo.head.is_valid()
            and not repo.bare
            and (base_path / ".git").is_dir()
            and (base_path / ".locks").is_dir()
        )

        if not basic_valid:
            return False

        # Check for duplicate pattern names
        pattern_files = list(base_path.rglob("*.md"))
        seen_names = {}

        for pattern_file in pattern_files:
            # Skip files in .git directory
            if ".git" in pattern_file.parts:
                continue

            # Get pattern name from the filename (without extension)
            pattern_name = pattern_file.stem

            if pattern_name in seen_names:
                logger.error(
                    f"Duplicate pattern file detected:\n"
                    f"  First occurrence: {seen_names[pattern_name]}\n"
                    f"  Second occurrence: {pattern_file}"
                )
                return False

            seen_names[pattern_name] = pattern_file

        return True

    except (InvalidGitRepositoryError, Exception) as e:
        logger.error(f"Repository verification failed: {e}")
        return False
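
Putting the PatternManager operations together, a minimal end-to-end sketch (hypothetical paths and pattern content; import path assumed from the source layout shown above):

from pathlib import Path
from tnh_scholar.ai_text_processing.patterns import Pattern, PatternManager  # assumed import path

base = Path("patterns")                        # hypothetical repository location
pm = PatternManager(base)

pattern = Pattern(
    name="format_dharma_talk",                 # hypothetical pattern
    instructions="Format the talk into {{ style }} paragraphs.",
    default_template_fields={"style": "short"},
)
rel_path = pm.save_pattern(pattern)            # written under a file lock and versioned in the Git-backed repo
print(rel_path)                                # format_dharma_talk.md (relative to base)

reloaded = pm.load_pattern("format_dharma_talk")
pm.show_pattern_history("format_dharma_talk")
print(PatternManager.verify_repository(base))  # True when the Git structure checks pass and names are unique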

response_format

TEXT_SECTIONS_DESCRIPTION = 'Ordered list of logical sections for the text. The sequence of line ranges for the sections must cover every line from start to finish without any overlaps or gaps.' module-attribute
LogicalSection

Bases: BaseModel

A logically coherent section of text.

Source code in src/tnh_scholar/ai_text_processing/response_format.py
class LogicalSection(BaseModel):
    """
    A logically coherent section of text.
    """

    title: str = Field(
        ...,
        description="Meaningful title for the section in the original language of the section.",
    )
    start_line: int = Field(
        ..., description="Starting line number of the section (inclusive)."
    )
    end_line: int = Field(
        ..., description="Ending line number of the section (inclusive)."
    )
end_line = Field(..., description='Ending line number of the section (inclusive).') class-attribute instance-attribute
start_line = Field(..., description='Starting line number of the section (inclusive).') class-attribute instance-attribute
title = Field(..., description='Meaningful title for the section in the original language of the section.') class-attribute instance-attribute
TextObject

Bases: BaseModel

Represents a text in any language broken into coherent logical sections.

Source code in src/tnh_scholar/ai_text_processing/response_format.py
class TextObject(BaseModel):
    """
    Represents a text in any language broken into coherent logical sections.
    """

    language: str = Field(..., description="ISO 639-1 language code of the text.")
    sections: List[LogicalSection] = Field(..., description=TEXT_SECTIONS_DESCRIPTION)
language = Field(..., description='ISO 639-1 language code of the text.') class-attribute instance-attribute
sections = Field(..., description=TEXT_SECTIONS_DESCRIPTION) class-attribute instance-attribute
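
These models can be built directly, or supplied as the structured response format for sectioning calls. A small sketch with hypothetical section data (import path assumed from the source layout):

from tnh_scholar.ai_text_processing.response_format import LogicalSection, TextObject  # assumed import path

text_obj = TextObject(
    language="vi",
    sections=[
        LogicalSection(title="Mở đầu", start_line=1, end_line=12),
        LogicalSection(title="Thực tập chánh niệm", start_line=13, end_line=48),
    ],
)
print(text_obj.model_dump_json(indent=2))  # pydantic v2 serialization; use .json() if running pydantic v1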

typing

ResponseFormat = TypeVar('ResponseFormat', bound=BaseModel) module-attribute

audio_processing

audio

EXPECTED_TIME_FACTOR = 0.45 module-attribute
MAX_DURATION = 10 * 60 module-attribute
MAX_DURATION_MS = 10 * 60 * 1000 module-attribute
MAX_INT16 = 32768.0 module-attribute
MIN_SILENCE_LENGTH = 1000 module-attribute
SEEK_LENGTH = 50 module-attribute
SILENCE_DBFS_THRESHOLD = -30 module-attribute
logger = get_child_logger('audio_processing') module-attribute
Boundary dataclass

A data structure representing a detected audio boundary.

Attributes:

Name Type Description
start float

Start time of the segment in seconds.

end float

End time of the segment in seconds.

text str

Associated text (empty if silence-based).

Example

b = Boundary(start=0.0, end=30.0, text="Hello world")
b.start, b.end, b.text
(0.0, 30.0, 'Hello world')

Source code in src/tnh_scholar/audio_processing/audio.py
@dataclass
class Boundary:
    """A data structure representing a detected audio boundary.

    Attributes:
        start (float): Start time of the segment in seconds.
        end (float): End time of the segment in seconds.
        text (str): Associated text (empty if silence-based).

    Example:
        >>> b = Boundary(start=0.0, end=30.0, text="Hello world")
        >>> b.start, b.end, b.text
        (0.0, 30.0, 'Hello world')
    """

    start: float
    end: float
    text: str = ""
end instance-attribute
start instance-attribute
text = '' class-attribute instance-attribute
__init__(start, end, text='')
audio_to_numpy(audio_segment)

Convert an AudioSegment object to a NumPy array suitable for Whisper.

Parameters:

Name Type Description Default
audio_segment AudioSegment

The input audio segment to convert.

required

Returns:

Type Description
ndarray

np.ndarray: A mono-channel NumPy array normalized to the range [-1, 1].

Example

audio = AudioSegment.from_file("example.mp3")
audio_numpy = audio_to_numpy(audio)

Source code in src/tnh_scholar/audio_processing/audio.py
def audio_to_numpy(audio_segment: AudioSegment) -> np.ndarray:
    """
    Convert an AudioSegment object to a NumPy array suitable for Whisper.

    Args:
        audio_segment (AudioSegment): The input audio segment to convert.

    Returns:
        np.ndarray: A mono-channel NumPy array normalized to the range [-1, 1].

    Example:
        >>> audio = AudioSegment.from_file("example.mp3")
        >>> audio_numpy = audio_to_numpy(audio)
    """
    # Convert the audio segment to raw sample data
    raw_data = np.array(audio_segment.get_array_of_samples()).astype(np.float32)

    # Normalize data to the range [-1, 1]
    raw_data /= MAX_INT16

    # Ensure mono-channel (use first channel if stereo)
    if audio_segment.channels > 1:
        raw_data = raw_data.reshape(-1, audio_segment.channels)[:, 0]

    return raw_data
detect_silence_boundaries(audio_file, min_silence_len=MIN_SILENCE_LENGTH, silence_thresh=SILENCE_DBFS_THRESHOLD, max_duration=MAX_DURATION_MS)

Detect boundaries (start/end times) based on silence detection.

Parameters:

Name Type Description Default
audio_file Path

Path to the audio file.

required
min_silence_len int

Minimum silence length to consider for splitting (ms).

MIN_SILENCE_LENGTH
silence_thresh int

Silence threshold in dBFS.

SILENCE_DBFS_THRESHOLD
max_duration int

Maximum duration of any segment (ms).

MAX_DURATION_MS

Returns:

Type Description
List[Boundary]

List[Boundary]: A list of boundaries with empty text.

Example

boundaries = detect_silence_boundaries(Path("my_audio.mp3"))
for b in boundaries:
    print(b.start, b.end)

Source code in src/tnh_scholar/audio_processing/audio.py
def detect_silence_boundaries(
    audio_file: Path,
    min_silence_len: int = MIN_SILENCE_LENGTH,
    silence_thresh: int = SILENCE_DBFS_THRESHOLD,
    max_duration: int = MAX_DURATION_MS,
) -> List[Boundary]:
    """
    Detect boundaries (start/end times) based on silence detection.

    Args:
        audio_file (Path): Path to the audio file.
        min_silence_len (int): Minimum silence length to consider for splitting (ms).
        silence_thresh (int): Silence threshold in dBFS.
        max_duration (int): Maximum duration of any segment (ms).

    Returns:
        List[Boundary]: A list of boundaries with empty text.

    Example:
        >>> boundaries = detect_silence_boundaries(Path("my_audio.mp3"))
        >>> for b in boundaries:
        ...     print(b.start, b.end)
    """
    logger.debug(
        f"Detecting silence boundaries with min_silence={min_silence_len}, silence_thresh={silence_thresh}"
    )

    audio = AudioSegment.from_file(audio_file)
    nonsilent_ranges = detect_nonsilent(
        audio,
        min_silence_len=min_silence_len,
        silence_thresh=silence_thresh,
        seek_step=SEEK_LENGTH,
    )

    # Combine ranges to enforce max_duration
    if not nonsilent_ranges:
        # If no nonsilent segments found, return entire file as one boundary
        duration_s = len(audio) / 1000.0
        return [Boundary(start=0.0, end=duration_s, text="")]

    combined_ranges = []
    current_start, current_end = nonsilent_ranges[0]
    for start, end in nonsilent_ranges[1:]:
        if (current_end - current_start) + (end - start) <= max_duration:
            # Extend the current segment
            current_end = end
        else:
            combined_ranges.append((current_start, current_end))
            current_start, current_end = start, end
    combined_ranges.append((current_start, current_end))

    return [
        Boundary(start=start_ms / 1000.0, end=end_ms / 1000.0, text="")
        for start_ms, end_ms in combined_ranges
    ]
detect_whisper_boundaries(audio_file, model_size='tiny', language=None)

Detect sentence boundaries using a Whisper model.

Parameters:

Name Type Description Default
audio_file Path

Path to the audio file.

required
model_size str

Whisper model size.

'tiny'
language str

Language to force for transcription (e.g. 'en', 'vi'), or None for auto.

None

Returns:

Type Description
Tuple[List[Boundary], Dict]

Tuple[List[Boundary], Dict]: Sentence boundaries with text, plus the full Whisper transcription result.

Example

boundaries, info = detect_whisper_boundaries(Path("my_audio.mp3"), model_size="tiny")
for b in boundaries:
    print(b.start, b.end, b.text)

Source code in src/tnh_scholar/audio_processing/audio.py
def detect_whisper_boundaries(
    audio_file: Path, model_size: str = "tiny", language: str = None
) -> Tuple[List[Boundary], Dict[str, Any]]:
    """
    Detect sentence boundaries using a Whisper model.

    Args:
        audio_file (Path): Path to the audio file.
        model_size (str): Whisper model size.
        language (str): Language to force for transcription (e.g. 'en', 'vi'), or None for auto.

    Returns:
        Tuple[List[Boundary], Dict[str, Any]]: Sentence boundaries with text,
        plus the full Whisper transcription result.

    Example:
        >>> boundaries, info = detect_whisper_boundaries(Path("my_audio.mp3"), model_size="tiny")
        >>> for b in boundaries:
        ...     print(b.start, b.end, b.text)
    """

    os.environ["KMP_WARNINGS"] = "0"  # Turn of OMP warning message

    # Load model
    logger.info("Loading Whisper model...")
    model = load_whisper_model(model_size)
    logger.info(f"Model '{model_size}' loaded.")

    if language:
        logger.info(f"Language for boundaries set to '{language}'")
    else:
        logger.info("Language not set. Autodetect will be used in Whisper model.")

    # with TimeProgress(expected_time=expected_time, desc="Generating transcription boundaries"):
    boundary_transcription = whisper_model_transcribe(
        model,
        str(audio_file),
        task="transcribe",
        word_timestamps=True,
        language=language,
        verbose=False,
    )

    sentence_boundaries = [
        Boundary(start=segment["start"], end=segment["end"], text=segment["text"])
        for segment in boundary_transcription["segments"]
    ]
    return sentence_boundaries, boundary_transcription
split_audio(audio_file, method='whisper', output_dir=None, model_size='tiny', language=None, min_silence_len=MIN_SILENCE_LENGTH, silence_thresh=SILENCE_DBFS_THRESHOLD, max_duration=MAX_DURATION)

High-level function to split an audio file into chunks based on a chosen method.

Parameters:

Name Type Description Default
audio_file Path

The input audio file.

required
method str

Splitting method, "silence" or "whisper".

'whisper'
output_dir Path

Directory to store output.

None
model_size str

Whisper model size if method='whisper'.

'tiny'
language str

Language for whisper transcription if method='whisper'.

None
min_silence_len int

For silence-based detection, min silence length in ms.

MIN_SILENCE_LENGTH
silence_thresh int

Silence threshold in dBFS.

SILENCE_DBFS_THRESHOLD
max_duration int

Max chunk length in seconds (converted to milliseconds internally for silence-based splitting).

MAX_DURATION

Returns:

Name Type Description
Path Path

Directory containing the resulting chunks.

Example
Split using silence detection

split_audio(Path("my_audio.mp3"), method="silence")

Split using whisper-based sentence boundaries

split_audio(Path("my_audio.mp3"), method="whisper", model_size="base", language="en")

Source code in src/tnh_scholar/audio_processing/audio.py
def split_audio(
    audio_file: Path,
    method: str = "whisper",
    output_dir: Path = None,
    model_size: str = "tiny",
    language: str = None,
    min_silence_len: int = MIN_SILENCE_LENGTH,
    silence_thresh: int = SILENCE_DBFS_THRESHOLD,
    max_duration: int = MAX_DURATION,
) -> Path:
    """
    High-level function to split an audio file into chunks based on a chosen method.

    Args:
        audio_file (Path): The input audio file.
        method (str): Splitting method, "silence" or "whisper".
        output_dir (Path): Directory to store output.
        model_size (str): Whisper model size if method='whisper'.
        language (str): Language for whisper transcription if method='whisper'.
        min_silence_len (int): For silence-based detection, min silence length in ms.
        silence_thresh (int): Silence threshold in dBFS.
        max_duration (int): Max chunk length in seconds (converted to milliseconds internally for silence-based splitting).

    Returns:
        Path: Directory containing the resulting chunks.

    Example:
        >>> # Split using silence detection
        >>> split_audio(Path("my_audio.mp3"), method="silence")

        >>> # Split using whisper-based sentence boundaries
        >>> split_audio(Path("my_audio.mp3"), method="whisper", model_size="base", language="en")
    """

    logger.info(f"Splitting audio with max_duration={max_duration} seconds")

    if method == "whisper":
        boundaries, _ = detect_whisper_boundaries(
            audio_file, model_size=model_size, language=language
        )

    elif method == "silence":
        max_duration_ms = (
            max_duration * 1000
        )  # convert duration in seconds to milliseconds
        boundaries = detect_silence_boundaries(
            audio_file,
            min_silence_len=min_silence_len,
            silence_thresh=silence_thresh,
            max_duration=max_duration_ms,
        )
    else:
        raise ValueError(f"Unknown method: {method}. Must be 'silence' or 'whisper'.")

    # Note: split_audio_at_boundaries clears any existing files in output_dir (useful for reprocessing)

    return split_audio_at_boundaries(
        audio_file, boundaries, output_dir=output_dir, max_duration=max_duration
    )
split_audio_at_boundaries(audio_file, boundaries, output_dir=None, max_duration=MAX_DURATION)

Split the audio file into chunks based on provided boundaries, ensuring all audio is included and boundaries align with the start of Whisper segments.

Parameters:

Name Type Description Default
audio_file Path

The input audio file.

required
boundaries List[Boundary]

Detected boundaries.

required
output_dir Path

Directory to store the resulting chunks.

None
max_duration int

Maximum chunk length in seconds.

MAX_DURATION

Returns:

Name Type Description
Path Path

Directory containing the chunked audio files.

Example

boundaries = [Boundary(34.02, 37.26, "..."), Boundary(38.0, 41.18, "...")]
out_dir = split_audio_at_boundaries(Path("my_audio.mp3"), boundaries)

Source code in src/tnh_scholar/audio_processing/audio.py
def split_audio_at_boundaries(
    audio_file: Path,
    boundaries: List[Boundary],
    output_dir: Path = None,
    max_duration: int = MAX_DURATION,
) -> Path:
    """
    Split the audio file into chunks based on provided boundaries, ensuring all audio is included
    and boundaries align with the start of Whisper segments.

    Args:
        audio_file (Path): The input audio file.
        boundaries (List[Boundary]): Detected boundaries.
        output_dir (Path): Directory to store the resulting chunks.
        max_duration (int): Maximum chunk length in seconds.

    Returns:
        Path: Directory containing the chunked audio files.

    Example:
        >>> boundaries = [Boundary(34.02, 37.26, "..."), Boundary(38.0, 41.18, "...")]
        >>> out_dir = split_audio_at_boundaries(Path("my_audio.mp3"), boundaries)
    """
    logger.info(f"Splitting audio with max_duration={max_duration} seconds")

    # Load the audio file
    audio = AudioSegment.from_file(audio_file)

    # Create output directory based on filename
    if output_dir is None:
        output_dir = audio_file.parent / f"{audio_file.stem}_chunks"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Clean up the output directory
    for file in output_dir.iterdir():
        if file.is_file():
            logger.info(f"Deleting existing file: {file}")
            file.unlink()

    chunk_start = 0  # Start time for the first chunk in ms
    chunk_count = 1
    current_chunk = AudioSegment.empty()

    for idx, boundary in enumerate(boundaries):
        segment_start_ms = int(boundary.start * 1000)
        if idx + 1 < len(boundaries):
            segment_end_ms = int(
                boundaries[idx + 1].start * 1000
            )  # Next boundary's start
        else:
            segment_end_ms = len(audio)  # End of the audio for the last boundary

        # Adjust for the first segment starting at 0
        if idx == 0 and segment_start_ms > 0:
            segment_start_ms = 0  # Ensure we include the very beginning of the audio

        segment = audio[segment_start_ms:segment_end_ms]

        logger.debug(
            f"Boundary index: {idx}, segment_start: {segment_start_ms / 1000}, segment_end: {segment_end_ms / 1000}, duration: {segment.duration_seconds}"
        )
        logger.debug(f"Current chunk Duration (s): {current_chunk.duration_seconds}")

        if len(current_chunk) + len(segment) <= max_duration * 1000:
            # Add segment to the current chunk
            current_chunk += segment
        else:
            # Export current chunk
            chunk_path = output_dir / f"chunk_{chunk_count}.mp3"
            current_chunk.export(chunk_path, format="mp3")
            logger.info(f"Exported: {chunk_path}")
            chunk_count += 1

            # Start a new chunk with the current segment
            current_chunk = segment

    # Export the final chunk if any audio remains
    if len(current_chunk) > 0:
        chunk_path = output_dir / f"chunk_{chunk_count}.mp3"
        current_chunk.export(chunk_path, format="mp3")
        logger.info(f"Exported: {chunk_path}")

    return output_dir
whisper_model_transcribe(model, input_source, *args, **kwargs)

Wrapper around model.transcribe that suppresses the known 'FP16 is not supported on CPU; using FP32 instead' UserWarning and redirects unwanted 'OMP' messages to prevent interference.

This function accepts all args and kwargs that model.transcribe normally does, and supports input sources as file paths (str or Path) or in-memory audio arrays.

Parameters:

Name Type Description Default
model Any

The Whisper model instance.

required
input_source Union[str, Path, ndarray]

Input audio file path, URL, or in-memory audio array.

required
*args

Additional positional arguments for model.transcribe.

()
**kwargs

Additional keyword arguments for model.transcribe.

{}

Returns:

Type Description
Dict[str, Any]

Dict[str, Any]: Transcription result from model.transcribe.

Example
Using a file path

result = whisper_model_transcribe(my_model, "sample_audio.mp3", verbose=True)

Using an audio array

result = whisper_model_transcribe(my_model, audio_array, language="en")

Source code in src/tnh_scholar/audio_processing/audio.py
def whisper_model_transcribe(
    model: Any,
    input_source: Any,
    *args,
    **kwargs,
) -> Dict[str, Any]:
    """
    Wrapper around model.transcribe that suppresses the known
    'FP16 is not supported on CPU; using FP32 instead' UserWarning
    and redirects unwanted 'OMP' messages to prevent interference.

    This function accepts all args and kwargs that model.transcribe normally does,
    and supports input sources as file paths (str or Path) or in-memory audio arrays.

    Parameters:
        model (Any): The Whisper model instance.
        input_source (Union[str, Path, np.ndarray]): Input audio file path, URL, or in-memory audio array.
        *args: Additional positional arguments for model.transcribe.
        **kwargs: Additional keyword arguments for model.transcribe.

    Returns:
        Dict[str, Any]: Transcription result from model.transcribe.

    Example:
        # Using a file path
        result = whisper_model_transcribe(my_model, "sample_audio.mp3", verbose=True)

        # Using an audio array
        result = whisper_model_transcribe(my_model, audio_array, language="en")
    """

    # class StdoutFilter(io.StringIO):
    #     def __init__(self, original_stdout):
    #         super().__init__()
    #         self.original_stdout = original_stdout

    #     def write(self, message):
    #         # Suppress specific messages like 'OMP:' while allowing others
    #         if "OMP:" not in message:
    #             self.original_stdout.write(message)

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="FP16 is not supported on CPU; using FP32 instead",
            category=UserWarning,
        )

        # Redirect stdout to suppress OMP messages
        # original_stdout = sys.stdout
        # sys.stdout = filtered_stdout

        try:
            # Convert Path to str if needed
            if isinstance(input_source, Path):
                input_source = str(input_source)

            # Call the original transcribe function
            return model.transcribe(input_source, *args, **kwargs)
        finally:
            # Restore original stdout
            # sys.stdout = original_stdout
            pass

transcription

logger = get_child_logger(__name__) module-attribute
custom_to_json(transcript)

Custom JSON conversion function to handle problematic float values from the OpenAI API interface.

Parameters:

Name Type Description Default
transcript Any

Object from OpenAI API's transcription.

required

Returns:

Name Type Description
str str

JSON string with problematic values fixed.

Source code in src/tnh_scholar/audio_processing/transcription.py
def custom_to_json(transcript: TranscriptionVerbose) -> str:
    """
    Custom JSON conversion function to handle problematic float values from the OpenAI API interface.

    Args:
        transcript (Any): Object from OpenAI API's transcription.

    Returns:
        str: JSON string with problematic values fixed.
    """
    logger.debug("Entered custom_to_json function.")
    try:
        # Use warnings.catch_warnings to catch specific warnings
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always", UserWarning)  # Catch all UserWarnings
            data = transcript.to_dict()

            # Check if any warnings were caught
            for warning in caught_warnings:
                if issubclass(warning.category, UserWarning):
                    warning_msg = str(warning.message)
                    if "Expected `str` but got `float`" in warning_msg:
                        logger.debug(
                            "Known UserWarning in OPENAI .to_dict() float serialization caught and ignored."
                        )
                    else:
                        logger.warning(
                            f"Unexpected warning during to_dict(): {warning_msg}"
                        )
    except Exception as e:
        logger.error(f"Error during to_dict(): {e}", exc_info=True)
        return json.dumps({})  # Return an empty JSON as a fallback

    # Traverse the dictionary and re-normalize problematic float values at fixed precision
    for key, value in data.items():
        if isinstance(value, float):  # Handle floats
            data[key] = float(f"{value:.18f}")

    # Serialize the cleaned dictionary back to JSON
    logger.debug("Dumping json in custom_to_json...")
    return json.dumps(data)
get_text_from_transcript(transcript)
Extracts and combines text from all segments of a transcription.

Args:
    transcript (TranscriptionVerbose): A transcription object containing segments of text.

Returns:
    str: A single string with all segment texts concatenated, separated by newlines.

Raises:
    ValueError: If the transcript object is invalid or missing required attributes.

Example:
    >>> from openai.types.audio.transcription_verbose import TranscriptionVerbose
    >>> transcript = TranscriptionVerbose(segments=[{"text": "Hello"}, {"text": "world"}])
    >>> get_text_from_transcript(transcript)
    'Hello\nworld'

Source code in src/tnh_scholar/audio_processing/transcription.py
def get_text_from_transcript(transcript: TranscriptionVerbose) -> str:
    """
    Extracts and combines text from all segments of a transcription.

    Args:
        transcript (TranscriptionVerbose): A transcription object containing segments of text.

    Returns:
        str: A single string with all segment texts concatenated, separated by newlines.

    Raises:
        ValueError: If the transcript object is invalid or missing required attributes.

    Example:
        >>> from openai.types.audio.transcription_verbose import TranscriptionVerbose
        >>> transcript = TranscriptionVerbose(segments=[{"text": "Hello"}, {"text": "world"}])
        >>> get_text_from_transcript(transcript)
        'Hello\nworld'
    """
    logger.debug(f"transcript is type: {type(transcript)}")

    return "\n".join(segment.text.strip() for segment in transcript.segments)
get_transcription(file, model, prompt, jsonl_out, mode='transcribe')
Source code in src/tnh_scholar/audio_processing/transcription.py
def get_transcription(
    file: Path, model: str, prompt: str, jsonl_out, mode="transcribe"
):
    """
    Transcribe (or translate) a single audio file and record the raw result.

    Runs run_transcription_speech with verbose JSON output, writes the serialized
    transcription object to the open JSONL file handle, and returns the plain
    transcript text extracted from the result.
    """
    logger.info(
        f"Speech transcript parameters: file={file}, model={model}, response_format=verbose_json, mode={mode}\n\tprompt='{prompt}'"
    )
    transcript = run_transcription_speech(
        file, model=model, response_format="verbose_json", prompt=prompt, mode=mode
    )

    # Use the custom_to_json function
    json_output = custom_to_json(transcript)
    logger.debug(f"Serialized JSON output excerpt: {json_output[:1000]}...")

    # Write the serialized JSON to the JSONL file
    jsonl_out.write(json_output + "\n")

    return get_text_from_transcript(transcript)
process_audio_chunks(directory, output_file, jsonl_file, model='whisper-1', prompt='', translate=False)

Processes all audio chunks in the specified directory using OpenAI's transcription API, saves the transcription objects into a JSONL file, and stitches the transcriptions into a single text file.

Parameters:

Name Type Description Default
directory Path

Path to the directory containing audio chunks.

required
output_file Path

Path to the output file to save the stitched transcription.

required
jsonl_file Path

Path to save the transcription objects as a JSONL file.

required
model str

The transcription model to use (default is "whisper-1").

'whisper-1'
prompt str

Optional prompt to provide context for better transcription.

''
translate bool

Optional flag to translate speech to English (useful if the audio input is not English)

False

Raises: FileNotFoundError: If no audio chunks are found in the directory.

Source code in src/tnh_scholar/audio_processing/transcription.py
def process_audio_chunks(
    directory: Path,
    output_file: Path,
    jsonl_file: Path,
    model: str = "whisper-1",
    prompt: str = "",
    translate: bool = False,
) -> None:
    """
    Processes all audio chunks in the specified directory using OpenAI's transcription API,
    saves the transcription objects into a JSONL file, and stitches the transcriptions
    into a single text file.

    Args:
        directory (Path): Path to the directory containing audio chunks.
        output_file (Path): Path to the output file to save the stitched transcription.
        jsonl_file (Path): Path to save the transcription objects as a JSONL file.
        model (str): The transcription model to use (default is "whisper-1").
        prompt (str): Optional prompt to provide context for better transcription.
        translate (bool): Optional flag to translate speech to English (useful if the audio input is not English)
    Raises:
        FileNotFoundError: If no audio chunks are found in the directory.
    """

    # Ensure the output directory exists
    output_file.parent.mkdir(parents=True, exist_ok=True)
    jsonl_file.parent.mkdir(parents=True, exist_ok=True)

    # Collect all audio chunks in the directory, sorting numerically by chunk number
    audio_files = sorted(
        directory.glob("*.mp3"),
        key=lambda f: int(f.stem.split("_")[1]),  # Extract the number from 'chunk_X'
    )

    if not audio_files:
        raise FileNotFoundError(f"No audio files found in the directory: {directory}")

    # log files to process:
    audio_file_names = [file.name for file in audio_files]  # get strings for logging
    audio_file_name_str = "\n\t".join(audio_file_names)
    audio_file_count = len(audio_file_names)
    logger.info(
        f"{audio_file_count} audio files found in {directory}:\n\t{audio_file_name_str}"
    )

    # Initialize the output content
    stitched_transcription = []

    # Open the JSONL file for writing
    with jsonl_file.open("w", encoding="utf-8") as jsonl_out:
        # Process each audio chunk
        for audio_file in audio_files:
            logger.info(f"Processing {audio_file.name}...")
            try:
                if translate:
                    text = get_transcription(
                        audio_file, model, prompt, jsonl_out, mode="translate"
                    )
                else:
                    text = get_transcription(
                        audio_file, model, prompt, jsonl_out, mode="transcribe"
                    )

                stitched_transcription.append(text)

            except Exception as e:
                logger.error(f"Error processing {audio_file.name}: {e}", exc_info=True)
                raise e

    # Write the stitched transcription to the output file
    with output_file.open("w", encoding="utf-8") as out_file:
        out_file.write(" ".join(stitched_transcription))

    logger.info(f"Stitched transcription saved to {output_file}")
    logger.info(f"Full transcript objects saved to {jsonl_file}")
process_audio_file(audio_file, output_file, jsonl_file, model='whisper-1', prompt='', translate=False)

Processes a single audio file using OpenAI's transcription API, saves the transcription objects into a JSONL file.

Parameters:

Name Type Description Default
audio_file Path

Path to the audio file to process.

required
output_file Path

Path to the output file to save the transcription.

required
jsonl_file Path

Path to save the transcription objects as a JSONL file.

required
model str

The transcription model to use (default is "whisper-1").

'whisper-1'
prompt str

Optional prompt to provide context for better transcription.

''
translate bool

Optional flag to translate speech to English (useful if the audio input is not English)

False

Raises: FileNotFoundError: If the audio file does not exist.

Source code in src/tnh_scholar/audio_processing/transcription.py
def process_audio_file(
    audio_file: Path,
    output_file: Path,
    jsonl_file: Path,
    model: str = "whisper-1",
    prompt: str = "",
    translate: bool = False,
) -> None:
    """
    Processes a single audio file using OpenAI's transcription API,
    saves the transcription objects into a JSONL file.

    Args:
        audio_file (Path): Path to the audio file to process.
        output_file (Path): Path to the output file to save the transcription.
        jsonl_file (Path): Path to save the transcription objects as a JSONL file.
        model (str): The transcription model to use (default is "whisper-1").
        prompt (str): Optional prompt to provide context for better transcription.
        translate (bool): Optional flag to translate speech to English (useful if the audio input is not English)
    Raises:
        FileNotFoundError: If the audio file does not exist.
    """

    # Ensure the output directory exists
    output_file.parent.mkdir(parents=True, exist_ok=True)
    jsonl_file.parent.mkdir(parents=True, exist_ok=True)

    if not audio_file.exists():
        raise FileNotFoundError(f"Audio file {audio_file} not found.")
    else:
        logger.info(f"Audio file found: {audio_file}")

    # Open the JSONL file for writing
    with jsonl_file.open("w", encoding="utf-8") as jsonl_out:
        logger.info(f"Processing {audio_file.name}...")
        try:
            if translate:
                text = get_transcription(
                    audio_file, model, prompt, jsonl_out, mode="translate"
                )
            else:
                text = get_transcription(
                    audio_file, model, prompt, jsonl_out, mode="transcribe"
                )
        except Exception as e:
            logger.error(f"Error processing {audio_file.name}: {e}", exc_info=True)
            raise e

    # Write the stitched transcription to the output file
    with output_file.open("w", encoding="utf-8") as out_file:
        out_file.write(text)

    logger.info(f"Transcription saved to {output_file}")
    logger.info(f"Full transcript objects saved to {jsonl_file}")

whisper_security

logger = get_child_logger(__name__) module-attribute
load_whisper_model(model_name)

Safely load a Whisper model with security best practices.

Parameters:

  model_name (str, required): Name of the Whisper model to load (e.g., "tiny", "base", "small").

Returns:

  Any: Loaded Whisper model.

Raises:

  RuntimeError: If model loading fails.

Source code in src/tnh_scholar/audio_processing/whisper_security.py
def load_whisper_model(model_name: str) -> Any:
    """
    Safely load a Whisper model with security best practices.

    Args:
        model_name: Name of the Whisper model to load (e.g., "tiny", "base", "small")

    Returns:
        Loaded Whisper model

    Raises:
        RuntimeError: If model loading fails
    """
    import whisper

    try:
        with safe_torch_load():
            model = whisper.load_model(model_name)
        return model
    except Exception as e:
        logger.error("Failed to load Whisper model %r: %s", model_name, e)
        raise RuntimeError(f"Failed to load Whisper model: {e}") from e
safe_torch_load(weights_only=True)

Context manager that temporarily modifies torch.load to use weights_only=True by default.

This addresses the FutureWarning in PyTorch regarding pickle security: https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models

Parameters:

  weights_only (bool, default True): If True, limits unpickling to tensor data only.

Yields:

  None

Example:

  >>> with safe_torch_load():
  ...     model = whisper.load_model("tiny")

Source code in src/tnh_scholar/audio_processing/whisper_security.py
@contextlib.contextmanager
def safe_torch_load(weights_only: bool = True) -> Generator[None, None, None]:
    """
    Context manager that temporarily modifies torch.load to use weights_only=True by default.

    This addresses the FutureWarning in PyTorch regarding pickle security:
    https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models

    Args:
        weights_only: If True, limits unpickling to tensor data only.

    Yields:
        None

    Example:
        >>> with safe_torch_load():
        ...     model = whisper.load_model("tiny")
    """
    original_torch_load = torch.load
    try:
        torch.load = partial(original_torch_load, weights_only=weights_only)
        logger.debug("Modified torch.load to use weights_only=%s", weights_only)
        yield
    finally:
        torch.load = original_torch_load
        logger.debug("Restored original torch.load")

cli_tools

TNH Scholar CLI Tools

Command-line interface tools for the TNH Scholar project:

audio-transcribe:
    Audio processing pipeline that handles downloading, segmentation,
    and transcription of Buddhist teachings.

tnh-fab:
    Text processing tool providing functionality for
    punctuation, sectioning, translation, and pattern-based processing.

See individual tool documentation for usage details and examples.

audio_transcribe

audio_transcribe

This module provides a command line interface for handling audio transcription tasks. It can optionally:

  • Download audio from a YouTube URL.
  • Split existing audio into chunks.
  • Transcribe audio chunks to text.

Usage Example
Download, split, and transcribe from a single YouTube URL

audio-transcribe --yt_url "https://www.youtube.com/watch?v=EXAMPLE" --split --transcribe --output_dir ./processed --prompt "Dharma, Deer Park..."

In a production environment, this CLI tool would be installed as part of the tnh-scholar package.

DEFAULT_CHUNK_DURATION_MIN = 7 module-attribute
DEFAULT_CHUNK_DURATION_SEC = DEFAULT_CHUNK_DURATION_MIN * 60 module-attribute
DEFAULT_OUTPUT_DIR = './audio_transcriptions' module-attribute
DEFAULT_PROMPT = 'Dharma, Deer Park, Thay, Thich Nhat Hanh, Bodhicitta, Bodhisattva, Mahayana' module-attribute
REQUIREMENTS_PATH = TNH_CLI_TOOLS_DIR / 'audio_transcribe' / 'environment' / 'requirements.txt' module-attribute
RE_DOWNLOAD_CONFIRMATION_STR = 'An mp3 file corresponding to {url} already exists in the output path:\n\t{output_dir}.\nSKIP download ([Y]/n)?' module-attribute
logger = get_child_logger('audio_transcribe') module-attribute
audio_transcribe(split, transcribe, yt_url, yt_url_csv, file, chunk_dir, output_dir, chunk_duration, no_chunks, start_time, translate, prompt, silence_boundaries, whisper_boundaries, language)

Entry point for the audio transcription pipeline. Depending on the provided flags and arguments, it can download audio from YouTube, split the audio into chunks, and/or transcribe the chunks.

Steps are:

  1. Download (if requested)

  2. Split (if requested)

  3. Transcribe (if requested)

Source code in src/tnh_scholar/cli_tools/audio_transcribe/audio_transcribe.py
@click.command()
@click.option(
    "-s", "--split", is_flag=True, help="Split downloaded/local audio into chunks."
)
@click.option("-t", "--transcribe", is_flag=True, help="Transcribe the audio chunks.")
@click.option("-y", "--yt_url", type=str, help="Single YouTube URL to process.")
@click.option(
    "-v",
    "--yt_url_csv",
    type=click.Path(exists=True),
    help="A CSV File containing multiple YouTube URLs. The first column of the file must be the URL and Second column a start time (if specified).",
)
@click.option(
    "-f", "--file", type=click.Path(exists=True), help="Path to a local audio file."
)
@click.option(
    "-c",
    "--chunk_dir",
    type=click.Path(),
    help="Directory for pre-existing chunks or where to store new chunks.",
)
@click.option(
    "-o",
    "--output_dir",
    type=click.Path(),
    default=DEFAULT_OUTPUT_DIR,
    help=f"Base output directory. DEFAULT: '{DEFAULT_OUTPUT_DIR}' ",
)
@click.option(
    "-d",
    "--chunk_duration",
    type=int,
    default=DEFAULT_CHUNK_DURATION_SEC,
    help=f"Max chunk duration in seconds (default: {DEFAULT_CHUNK_DURATION_MIN} minutes).",
)
@click.option(
    "-x",
    "--no_chunks",
    is_flag=True,
    help="Run transcription directly on the audio files source(s). WARNING: for files > 10 minutes in Length, the Open AI transcription API may fail.",
)
@click.option(
    "-b",
    "--start",
    "start_time",
    type=str,
    help="Start time (beginning) offset for the input media (HH:MM:SS).",
)
@click.option(
    "-a",
    "--translate",
    is_flag=True,
    help="Include translation in the transcription if set.",
)
@click.option(
    "-p",
    "--prompt",
    type=str,
    default=DEFAULT_PROMPT,
    help="Prompt or keywords to guide the transcription.",
)
@click.option(
    "-i",
    "--silence_boundaries",
    is_flag=True,
    help="Use silence detection to split audio file(s)",
)
@click.option(
    "-w",
    "--whisper_boundaries",
    is_flag=True,
    help="(DEFAULT) Use a whisper based model to audio at sentence boundaries.",
)
@click.option(
    "-l",
    "--language",
    type=str,
    help="The two letter language code. e.g. 'vi' for Vietnamese. Used for splitting only. DEFAULT: English ('en').",
)
def audio_transcribe(
    split: bool,
    transcribe: bool,
    yt_url: str | None,
    yt_url_csv: str | None,
    file: str | None,
    chunk_dir: str | None,
    output_dir: str,
    chunk_duration: int,
    no_chunks: bool,
    start_time: str | None,
    translate: bool,
    prompt: str,
    silence_boundaries: bool,
    whisper_boundaries: bool,
    language: str | None,
) -> None:
    """
    Entry point for the audio transcription pipeline.
    Depending on the provided flags and arguments, it can download audio from YouTube,
    split the audio into chunks, and/or transcribe the chunks.

    Steps are:

    1. Download (if requested)

    2. Split (if requested)

    3. Transcribe (if requested)
    """

    check_ytd_version()  # Do a version check on startup. Version issues can cause yt-dlp to fail.

    logger.info("Starting audio transcription pipeline...")

    # initial parameter processing
    if not split and not transcribe:  # if neither set, we assume both.
        split = True
        transcribe = True

    is_download = bool(yt_url or yt_url_csv)
    if not language:
        language = "en"

    # default logic for splitting boundaries
    if not whisper_boundaries and not silence_boundaries:
        whisper_boundaries = True

    try:
        # Validate input arguments
        audio_file: Path | None = Path(file) if file else None
        chunk_directory: Path | None = Path(chunk_dir) if chunk_dir else None
        out_dir = Path(output_dir)

        validate_inputs(
            is_download=is_download,
            yt_url=yt_url,
            yt_url_list=Path(yt_url_csv) if yt_url_csv else None,
            audio_file=audio_file,
            split=split,
            transcribe=transcribe,
            chunk_dir=chunk_directory,
            no_chunks=no_chunks,
            silence_boundaries=silence_boundaries,
            whisper_boundaries=whisper_boundaries,
        )

        # Determine the list of URLs if we are downloading from YouTube
        urls: list[str] = []
        if yt_url_csv:
            if is_download:
                urls = get_youtube_urls_from_csv(Path(yt_url_csv))
        elif yt_url:
            if is_download:
                urls = [yt_url]

        # If we are downloading from YouTube, handle that
        downloaded_files: list[Path] = []
        if is_download:
            for url in urls:
                download_path = get_video_download_path_yt(out_dir, url)
                if download_path.exists():
                    if get_user_confirmation(
                        RE_DOWNLOAD_CONFIRMATION_STR.format(url=url, output_dir=out_dir)
                    ):
                        logger.info(f"Skipping download for {url}.")
                    else:
                        logger.info(f"Re-downloading {url}:")
                        download_path = download_audio_yt(
                            url, out_dir, start_time=start_time
                        )
                        logger.info(f"Successfully downloaded {url} to {download_path}")
                else:
                    logger.info(f"Downloading from YouTube: {url}")
                    ensure_directory_exists(out_dir)
                    download_path = download_audio_yt(
                        url, out_dir, start_time=start_time
                    )
                    logger.info(f"Successfully downloaded {url} to {download_path}")

                downloaded_files.append(download_path)

        # If we have a local audio file specified (no yt_download), treat that as our input
        if audio_file and not is_download:
            downloaded_files = [audio_file]

        # If splitting is requested, split either the downloaded files or the provided audio
        if split:
            for audio_file in downloaded_files:
                audio_name = audio_file.stem
                audio_output_dir = out_dir / audio_name
                ensure_directory_exists(audio_output_dir)
                chunk_output_dir = chunk_directory or audio_output_dir / "chunks"
                ensure_directory_exists(chunk_output_dir)

                logger.info(f"Splitting audio into chunks for {audio_file}")

                # Whisper-based boundary detection is the default; use silence
                # detection only when explicitly requested.
                detection_method = "silence" if silence_boundaries else "whisper"
                split_audio(
                    audio_file=audio_file,
                    method=detection_method,
                    output_dir=chunk_output_dir,
                    max_duration=chunk_duration,
                    language=language,
                )

        # If transcribe is requested, we must have a chunk directory to transcribe from
        if transcribe:
            for audio_file in downloaded_files:
                audio_name = audio_file.stem
                audio_output_dir = out_dir / audio_name
                transcript_file = audio_output_dir / f"{audio_name}.txt"
                if no_chunks:
                    jsonl_file = audio_output_dir / f"{audio_name}.jsonl"
                    logger.info(
                        f"Transcribing {audio_name} directly without chunking..."
                    )
                    process_audio_file(
                        audio_file=audio_file,
                        output_file=transcript_file,
                        jsonl_file=jsonl_file,
                        prompt=prompt,
                        translate=translate,
                    )

                else:
                    chunk_output_dir = chunk_directory or audio_output_dir / "chunks"
                    jsonl_file = audio_output_dir / f"{audio_name}.jsonl"
                    logger.info(f"Transcribing chunks from {chunk_output_dir}")
                    process_audio_chunks(
                        directory=chunk_output_dir,
                        output_file=transcript_file,
                        jsonl_file=jsonl_file,
                        prompt=prompt,
                        translate=translate,
                    )

        logger.info("Audio transcription pipeline completed successfully.")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        logger.debug("traceback info", exc_info=True)
        sys.exit(1)
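
A sketch of driving the pipeline programmatically through click's test runner. The local file path is hypothetical and must exist for the click.Path(exists=True) check to pass; omitting both --split and --transcribe enables both steps, as described above, and a real run would call the network and the OpenAI API.

from click.testing import CliRunner
from tnh_scholar.cli_tools.audio_transcribe.audio_transcribe import audio_transcribe  # inferred import path

runner = CliRunner()
result = runner.invoke(
    audio_transcribe,
    ["--file", "talk.mp3", "--output_dir", "./processed"],  # hypothetical arguments
)
print(result.exit_code, result.output)
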
main()

Entry point for AUDIO-TRANSCRIBE CLI tool.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/audio_transcribe.py
def main():
    """Entry point for AUDIO-TRANSCRIBE CLI tool."""
    audio_transcribe()
environment
env
logger = get_child_logger(__name__) module-attribute
check_env()

Check the environment for necessary conditions: 1. Check that the OpenAI key is available. 2. Check that ffmpeg is available on the system PATH.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/environment/env.py
def check_env() -> bool:
    """
    Check the environment for necessary conditions:
    1. Check that the OpenAI key is available.
    2. Check that ffmpeg is available on the system PATH.
    """
    logger.debug("checking environment.")

    if not check_openai_env():
        return False

    if shutil.which("ffmpeg") is None:
        logger.error("ffmpeg not found in PATH. ffmpeg required for audio processing.")
        return False

    return True
check_requirements(requirements_file)

Check that all requirements listed in requirements.txt can be imported. If any cannot be imported, print a warning.

This is a heuristic check. Some packages may not share the same name as their importable module. Adjust the name mappings below as needed.

Example:

  >>> check_requirements(Path("./requirements.txt"))
  # Prints warnings if imports fail, otherwise silent.
Source code in src/tnh_scholar/cli_tools/audio_transcribe/environment/env.py
def check_requirements(requirements_file: Path) -> None:
    """
    Check that all requirements listed in requirements.txt can be imported.
    If any cannot be imported, print a warning.

    This is a heuristic check. Some packages may not share the same name as their importable module.
    Adjust the name mappings below as needed.

    Example:
        >>> check_requirements(Path("./requirements.txt"))
        # Prints warnings if imports fail, otherwise silent.
    """
    # Map requirement names to their importable module names if they differ
    name_map = {
        "python-dotenv": "dotenv",
        "openai_whisper": "whisper",
        "protobuf": "google.protobuf",
        # Add other mappings if needed
    }

    # Parse requirements.txt to get a list of package names
    packages = []
    with requirements_file.open("r") as req_file:
        for line in req_file:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            # Each line generally looks like 'package==version'
            pkg_name = line.split("==")[0].strip()
            packages.append(pkg_name)

    # Try importing each package
    for pkg in packages:
        mod_name = name_map.get(pkg, pkg)
        try:
            __import__(mod_name)
        except ImportError:
            print(
                f"WARNING: Could not import '{mod_name}' from '{pkg}'. Check that it is correctly installed."
            )
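
As a rough usage sketch (the temporary requirements file and its contents are made up; the import path is inferred from the source location above), the heuristic import check can be exercised like this:

import tempfile
from pathlib import Path
from tnh_scholar.cli_tools.audio_transcribe.environment.env import check_requirements  # inferred import path

# Hypothetical requirements content; comments and blank lines are skipped.
with tempfile.TemporaryDirectory() as tmp:
    req = Path(tmp) / "requirements.txt"
    req.write_text("click==8.1.7\npython-dotenv==1.0.1\n# pinned for illustration\n")
    check_requirements(req)  # prints a warning for any module that cannot be imported
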
validate
validate_inputs(is_download, yt_url, yt_url_list, audio_file, split, transcribe, chunk_dir, no_chunks, silence_boundaries, whisper_boundaries)

Validate the CLI inputs to ensure logical consistency given all the flags.

Conditions & Requirements: 1. At least one action (yt_download, split, transcribe) should be requested. Otherwise, nothing is done, so raise an error.

  1. If yt_download is True:
  2. Must specify either yt_process_url OR yt_process_url_list (not both, not none).

  3. If yt_download is False:

  4. If split is requested, we need a local audio file (since no download will occur).
  5. If transcribe is requested without split and without yt_download:

    • If no_chunks = False, we must have chunk_dir to read existing chunks.
    • If no_chunks = True, we must have a local audio file (direct transcription) or previously downloaded file (but since yt_download=False, previously downloaded file scenario doesn't apply here, so effectively we need local audio in that scenario).
  6. no_chunks flag:

  7. If no_chunks = True, we are doing direct transcription on entire audio without chunking.

    • Cannot use split if no_chunks = True. (Mutually exclusive)
    • chunk_dir is irrelevant if no_chunks = True; since we don't split into chunks, requiring a chunk_dir doesn't make sense. If provided, it's not useful, but let's allow it silently or raise an error for clarity. It's safer to raise an error to prevent user confusion.
  8. Boundaries flags (silence_boundaries, whisper_boundaries):

  9. These flags control how splitting is done.
  10. If split = False, these are irrelevant. Not necessarily an error, but could be a no-op. For robustness, raise an error if user specifies these without split, to avoid confusion.
  11. If split = True and no_chunks = True, that’s contradictory already, so no need for boundary logic there.
  12. If split = True, exactly one method should be chosen: If both silence_boundaries and whisper_boundaries are True simultaneously or both are False simultaneously, we need a clear default or raise an error. By the code snippet logic, whisper_boundaries is default True if not stated otherwise. To keep it robust:
    • If both are True, raise error.
    • If both are False, that means user explicitly turned them off or never turned on whisper. The code snippet sets whisper_boundaries True by default. If user sets it False somehow, we can then default to silence. Just ensure at run-time we have a deterministic method: If both are False, we can default to whisper or silence. Let's default to whisper if no flags given. However, given the code snippet, whisper_boundaries has a default of True. If the user sets whisper_boundaries to False and also does not set silence_boundaries, then no method is chosen. Let's then raise an error if both ended up False to avoid ambiguity.

Raises:

Type Description
ValueError

If the input arguments are not logically consistent.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/validate.py
def validate_inputs(
    is_download: bool,
    yt_url: str | None,
    yt_url_list: Path | None,
    audio_file: Path | None,
    split: bool,
    transcribe: bool,
    chunk_dir: Path | None,
    no_chunks: bool,
    silence_boundaries: bool,
    whisper_boundaries: bool,
) -> None:
    """
    Validate the CLI inputs to ensure logical consistency given all the flags.

    Conditions & Requirements:
    1. At least one action (yt_download, split, transcribe) should be requested.
       Otherwise, nothing is done, so raise an error.

    2. If yt_download is True:
       - Must specify either yt_process_url OR yt_process_url_list (not both, not none).

    3. If yt_download is False:
       - If split is requested, we need a local audio file (since no download will occur).
       - If transcribe is requested without split and without yt_download:
         - If no_chunks = False, we must have chunk_dir to read existing chunks.
         - If no_chunks = True, we must have a local audio file (direct transcription) or previously downloaded file
           (but since yt_download=False, previously downloaded file scenario doesn't apply here,
           so effectively we need local audio in that scenario).

    4. no_chunks flag:
       - If no_chunks = True, we are doing direct transcription on entire audio without chunking.
         - Cannot use split if no_chunks = True. (Mutually exclusive)
         - chunk_dir is irrelevant if no_chunks = True; since we don't split into chunks,
           requiring a chunk_dir doesn't make sense. If provided, it's not useful, but let's allow it silently
           or raise an error for clarity. It's safer to raise an error to prevent user confusion.

    5. Boundaries flags (silence_boundaries, whisper_boundaries):
       - These flags control how splitting is done.
       - If split = False, these are irrelevant. Not necessarily an error, but could be a no-op.
         For robustness, raise an error if user specifies these without split, to avoid confusion.
       - If split = True and no_chunks = True, that’s contradictory already, so no need for boundary logic there.
       - If split = True, exactly one method should be chosen:
         If both silence_boundaries and whisper_boundaries are True simultaneously or both are False simultaneously,
         we need a clear default or raise an error. By the code snippet logic, whisper_boundaries is default True
         if not stated otherwise. To keep it robust:
           - If both are True, raise error.
           - If both are False, that means user explicitly turned them off or never turned on whisper.
             The code snippet sets whisper_boundaries True by default. If user sets it False somehow,
             we can then default to silence. Just ensure at run-time we have a deterministic method:
             If both are False, we can default to whisper or silence. Let's default to whisper if no flags given.
             However, given the code snippet, whisper_boundaries has a default of True.
             If the user sets whisper_boundaries to False and also does not set silence_boundaries,
             then no method is chosen. Let's then raise an error if both ended up False to avoid ambiguity.

    Raises:
        ValueError: If the input arguments are not logically consistent.
    """

    # 1. Check that we have at least one action
    if not is_download and not split and not transcribe:
        raise ValueError(
            "No actions requested. At least one of --yt_download, --split, --transcribe, or --full must be set."
        )

    # 2. Validate YouTube download logic
    if is_download:
        if yt_url and yt_url_list:
            raise ValueError(
                "Both --yt_process_url and --yt_process_url_list provided. Only one allowed."
            )
        if not yt_url and not yt_url_list:
            raise ValueError(
                "When --yt_download is specified, you must provide --yt_process_url or --yt_process_url_list."
            )

    # 3. Logic when no YouTube download:
    if not is_download:
        # If splitting but no download, need an audio file
        if split and audio_file is None:
            raise ValueError(
                "Splitting requested but no audio file provided and no YouTube download source available."
            )

        if transcribe and not split:
            if no_chunks:
                # Direct transcription, need an audio file
                if audio_file is None:
                    raise ValueError(
                        "Transcription requested with no_chunks=True but no audio file provided."
                    )
            elif chunk_dir is None:
                raise ValueError(
                    "Transcription requested without splitting or downloading and no_chunks=False. Must provide --chunk_dir with pre-split chunks."
                )

    # Check no_chunks scenario:
    # no_chunks and split are mutually exclusive
    # If transcribing but not splitting or downloading:
    # If no_chunks and chunk_dir provided, it doesn't make sense since we won't use chunks at all.
    # 4. no_chunks flag validation:
    # no_chunks=False, we need chunks from chunk_dir
    if no_chunks:
        if split:
            raise ValueError(
                "Cannot use --no_chunks and --split together. Choose one option."
            )
        if chunk_dir is not None:
            raise ValueError("Cannot specify --chunk_dir when --no_chunks is set.")

    # 5. Boundaries flags:
    # If splitting is not requested but boundaries flags are set, it's meaningless.
    # The code snippet defaults whisper_boundaries to True, so if user tries to turn it off and sets silence?
    # We'll require that boundaries only matter if split is True.
    if not split and (silence_boundaries or whisper_boundaries):
        raise ValueError(
            "Boundary detection flags given but splitting is not requested. Remove these flags or enable --split."
        )

    # If split is True, we must have a consistent boundary method:
    if split:
        # If both whisper and silence are somehow True:
        if silence_boundaries and whisper_boundaries:
            raise ValueError(
                "Cannot use both --silence_boundaries and --whisper_boundaries simultaneously."
            )

        # If both are False:
        # Given the original snippet, whisper_boundaries is True by default.
        # For the sake of robustness, let's say if user sets both off, we can't proceed:
        if not silence_boundaries and not whisper_boundaries:
            raise ValueError(
                "No boundary method selected for splitting. Enable either whisper or silence boundaries."
            )
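
To illustrate rule 3 above with concrete (hypothetical) arguments: requesting transcription with no_chunks=True but no local audio file is rejected. The import path is inferred from the source location above.

from tnh_scholar.cli_tools.audio_transcribe.validate import validate_inputs  # inferred import path

try:
    validate_inputs(
        is_download=False,
        yt_url=None,
        yt_url_list=None,
        audio_file=None,        # no local file supplied
        split=False,
        transcribe=True,
        chunk_dir=None,
        no_chunks=True,         # direct transcription requested
        silence_boundaries=False,
        whisper_boundaries=True,
    )
except ValueError as err:
    print(err)  # explains that direct transcription needs an audio file
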
version_check
logger = get_child_logger(__name__) module-attribute
YTDVersionChecker

Simple version checker for yt-dlp with robust version comparison.

This is a prototype implementation that may need expansion in these areas:

  • Caching to prevent frequent PyPI calls
  • More comprehensive error handling for:
    • Missing/uninstalled packages
    • Network timeouts
    • JSON parsing errors
    • Invalid version strings
  • Environment detection (virtualenv, conda, system Python)
  • Configuration options for version pinning
  • Proxy support for network requests

Source code in src/tnh_scholar/cli_tools/audio_transcribe/version_check.py
class YTDVersionChecker:
    """
    Simple version checker for yt-dlp with robust version comparison.

    This is a prototype implementation that may need expansion in these areas:
    - Caching to prevent frequent PyPI calls
    - More comprehensive error handling for:
        - Missing/uninstalled packages
        - Network timeouts
        - JSON parsing errors
        - Invalid version strings
    - Environment detection (virtualenv, conda, system Python)
    - Configuration options for version pinning
    - Proxy support for network requests
    """

    PYPI_URL = "https://pypi.org/pypi/yt-dlp/json"
    NETWORK_TIMEOUT = 5  # seconds

    def _get_installed_version(self) -> Version:
        """
        Get installed yt-dlp version.

        Returns:
            Version object representing installed version

        Raises:
            ImportError: If yt-dlp is not installed
            InvalidVersion: If installed version string is invalid
        """
        try:
            if version_str := str(importlib.metadata.version("yt-dlp")):
                return Version(version_str)
            else:
                raise InvalidVersion("yt-dlp version string is empty")
        except importlib.metadata.PackageNotFoundError as e:
            raise ImportError("yt-dlp is not installed") from e
        except InvalidVersion:
            raise

    def _get_latest_version(self) -> Version:
        """
        Get latest version from PyPI.

        Returns:
            Version object representing latest available version

        Raises:
            requests.RequestException: For any network-related errors
            InvalidVersion: If PyPI version string is invalid
            KeyError: If PyPI response JSON is malformed
        """
        try:
            response = requests.get(self.PYPI_URL, timeout=self.NETWORK_TIMEOUT)
            response.raise_for_status()
            version_str = response.json()["info"]["version"]
            return Version(version_str)
        except requests.RequestException as e:
            raise requests.RequestException(
                "Failed to fetch version from PyPI. Check network connection."
            ) from e

    def check_version(self) -> Tuple[bool, Version, Version]:
        """
        Check if yt-dlp needs updating.

        Returns:
            Tuple of (needs_update, installed_version, latest_version)

        Raises:
            ImportError: If yt-dlp is not installed
            requests.RequestException: For network-related errors
            InvalidVersion: If version strings are invalid
        """
        installed_version = self._get_installed_version()
        latest_version = self._get_latest_version()

        needs_update = installed_version < latest_version
        return needs_update, installed_version, latest_version
NETWORK_TIMEOUT = 5 class-attribute instance-attribute
PYPI_URL = 'https://pypi.org/pypi/yt-dlp/json' class-attribute instance-attribute
check_version()

Check if yt-dlp needs updating.

Returns:

  Tuple[bool, Version, Version]: Tuple of (needs_update, installed_version, latest_version).

Raises:

  ImportError: If yt-dlp is not installed.
  requests.RequestException: For network-related errors.
  InvalidVersion: If version strings are invalid.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/version_check.py
def check_version(self) -> Tuple[bool, Version, Version]:
    """
    Check if yt-dlp needs updating.

    Returns:
        Tuple of (needs_update, installed_version, latest_version)

    Raises:
        ImportError: If yt-dlp is not installed
        requests.RequestException: For network-related errors
        InvalidVersion: If version strings are invalid
    """
    installed_version = self._get_installed_version()
    latest_version = self._get_latest_version()

    needs_update = installed_version < latest_version
    return needs_update, installed_version, latest_version
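
The final comparison relies on packaging.version.Version ordering rather than naive string comparison; a minimal illustration:

from packaging.version import Version

# Numeric release segments compare numerically, so 2024.5.x < 2024.12.x.
assert Version("2024.5.27") < Version("2024.12.6")
# Leading zeros within a segment are normalized away.
assert Version("2024.12.06") == Version("2024.12.6")
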
check_ytd_version()

Check if yt-dlp needs updating and log appropriate messages.

This function checks the installed version of yt-dlp against the latest version on PyPI and logs informational or error messages as appropriate. It handles network errors, missing packages, and version parsing issues gracefully.

The function does not raise exceptions but logs them using the application's logging system.

Source code in src/tnh_scholar/cli_tools/audio_transcribe/version_check.py
def check_ytd_version() -> None:
    """
    Check if yt-dlp needs updating and log appropriate messages.

    This function checks the installed version of yt-dlp against the latest version
    on PyPI and logs informational or error messages as appropriate. It handles
    network errors, missing packages, and version parsing issues gracefully.

    The function does not raise exceptions but logs them using the application's
    logging system.
    """
    checker = YTDVersionChecker()
    try:
        needs_update, current, latest = checker.check_version()
        if needs_update:
            logger.info(f"Update available: {current} -> {latest}")
            logger.info("Please run the appropriate upgrade in your environment.")
            logger.info("   For example: pip install --upgrade yt-dlp ")
        else:
            logger.info(f"yt-dlp is up to date (version {current})")

    except ImportError as e:
        logger.error(f"In yt-dlp version check: Package error: {e}")
    except requests.RequestException as e:
        logger.error(f"In yt-dlp version check: Network error: {e}")
    except InvalidVersion as e:
        logger.error(f"In yt-dlp version check: Version parsing error: {e}")
    except Exception as e:
        logger.error(f"In yt-dlp version check: Unexpected error: {e}")

nfmt

nfmt
main()

Entry point for the nfmt CLI tool.

Source code in src/tnh_scholar/cli_tools/nfmt/nfmt.py
def main():
    """Entry point for the nfmt CLI tool."""
    nfmt()
nfmt(input_file, output, spacing)

Normalize the number of newlines in a text file.

Source code in src/tnh_scholar/cli_tools/nfmt/nfmt.py
@click.command()
@click.argument("input_file", type=click.File("r"), default="-")
@click.option(
    "-o",
    "--output",
    type=click.File("w"),
    default="-",
    help="Output file (default: stdout)",
)
@click.option(
    "-s", "--spacing", default=2, help="Number of newlines between blocks (default: 2)"
)
def nfmt(input_file, output, spacing):
    """Normalize the number of newlines in a text file."""
    text = input_file.read()
    result = normalize_newlines(text, spacing)
    output.write(result)
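
A quick sketch of the command run through click's test runner. The input text is made up, the import path is inferred from the source location above, and the exact collapsing behaviour is assumed from the name of normalize_newlines and the --spacing option.

from click.testing import CliRunner
from tnh_scholar.cli_tools.nfmt.nfmt import nfmt  # inferred import path

runner = CliRunner()
# Feed text on stdin (the default "-" argument) and read the result from stdout.
result = runner.invoke(nfmt, ["-s", "2"], input="First block.\n\n\n\nSecond block.\n")
print(result.output)  # blank-line runs collapsed to the requested spacing (assumed behaviour)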

sent_split

sent_split

Simple CLI tool for sentence splitting.

This module provides a command line interface for splitting text into sentences. Uses NLTK for robust sentence tokenization. Reads from stdin and writes to stdout by default, with optional file input/output.

ensure_nltk_data()

Ensure NLTK punkt tokenizer is available.

Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
def ensure_nltk_data():
    """Ensure NLTK punkt tokenizer is available."""
    try:
        # Try to find the resource
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        # If not found, try downloading
        try:
            nltk.download('punkt', quiet=True)
            # Verify download
            nltk.data.find('tokenizers/punkt')
        except Exception as e:
            raise RuntimeError(
                "Failed to download required NLTK data. "
                "Please run 'python -m nltk.downloader punkt' "
                f"to install manually. Error: {e}"
            ) from e
main()
Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
def main():
    sent_split()
process_text(text, newline=True)

Split text into sentences using NLTK.

Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
def process_text(text: str, newline: bool = True) -> str:
    """Split text into sentences using NLTK."""
    ensure_nltk_data()
    sentences = sent_tokenize(text)
    return "\n".join(sentences) if newline else " ".join(sentences)
sent_split(input_file, output, space)

Split text into sentences using NLTK's sentence tokenizer.

Reads from stdin if no input file is specified. Writes to stdout if no output file is specified.

Source code in src/tnh_scholar/cli_tools/sent_split/sent_split.py
@click.command()
@click.argument('input_file', type=click.File('r'), required=False)
@click.option('-o', '--output', type=click.File('w'),
              help='Output file (default: stdout)')
@click.option('-s', '--space', is_flag=True,
              help='Separate sentences with spaces instead of newlines')
def sent_split(input_file: Optional[TextIO],
               output: Optional[TextIO],
               space: bool) -> None:
    """Split text into sentences using NLTK's sentence tokenizer.

    Reads from stdin if no input file is specified.
    Writes to stdout if no output file is specified.
    """
    try:
        # Read from file or stdin
        input_text = input_file.read() if input_file else sys.stdin.read()

        # Process the text
        result = process_text(input_text, newline=not space)

        # Write to file or stdout
        output_file = output or sys.stdout
        output_file.write(result)

        if output:
            click.echo(f"Output written to: {output.name}")

    except Exception as e:
        click.echo(f"Error processing text: {e}", err=True)
        sys.exit(1)

tnh_fab

tnh_fab

TNH-FAB Command Line Interface

Part of the THICH NHAT HANH SCHOLAR (TNH_SCHOLAR) project. A rapid prototype implementation of the TNH-FAB command-line tool for OpenAI-based text processing. Provides core functionality for text punctuation, sectioning, translation, and processing.

DEFAULT_SECTION_PATTERN = 'default_section' module-attribute
logger = get_child_logger(__name__) module-attribute
pass_config = click.make_pass_decorator(TNHFabConfig, ensure=True) module-attribute
TNHFabConfig

Holds configuration for the TNH-FAB CLI tool.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
class TNHFabConfig:
    """Holds configuration for the TNH-FAB CLI tool."""

    def __init__(self):
        self.verbose: bool = False
        self.debug: bool = False
        self.quiet: bool = False
        # Initialize pattern manager with directory set in .env file or default.

        load_dotenv()

        if pattern_path_name := os.getenv("TNH_PATTERN_DIR"):
            pattern_dir = Path(pattern_path_name)
            logger.debug(f"pattern dir: {pattern_path_name}")
        else:
            pattern_dir = TNH_DEFAULT_PATTERN_DIR

        pattern_dir.mkdir(parents=True, exist_ok=True)
        self.pattern_manager = PatternManager(pattern_dir)
debug = False instance-attribute
pattern_manager = PatternManager(pattern_dir) instance-attribute
quiet = False instance-attribute
verbose = False instance-attribute
__init__()
Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def __init__(self):
    self.verbose: bool = False
    self.debug: bool = False
    self.quiet: bool = False
    # Initialize pattern manager with directory set in .env file or default.

    load_dotenv()

    if pattern_path_name := os.getenv("TNH_PATTERN_DIR"):
        pattern_dir = Path(pattern_path_name)
        logger.debug(f"pattern dir: {pattern_path_name}")
    else:
        pattern_dir = TNH_DEFAULT_PATTERN_DIR

    pattern_dir.mkdir(parents=True, exist_ok=True)
    self.pattern_manager = PatternManager(pattern_dir)
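
For illustration, the TNH_PATTERN_DIR override could be exercised like this. The directory path is hypothetical, the import path is inferred from the source location above, and load_dotenv does not overwrite variables already set in the environment.

import os
from tnh_scholar.cli_tools.tnh_fab.tnh_fab import TNHFabConfig  # inferred import path

os.environ["TNH_PATTERN_DIR"] = "/tmp/my_patterns"  # hypothetical pattern directory
config = TNHFabConfig()                             # creates the directory if needed
print(config.pattern_manager.base_path)             # expected to point at /tmp/my_patterns
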
get_pattern(pattern_manager, pattern_name)

Get pattern from the pattern manager.

Parameters:

  pattern_manager (PatternManager, required): Initialized PatternManager instance.
  pattern_name (str, required): Name of the pattern to load.

Returns:

  Pattern: Loaded pattern object.

Raises:

  click.ClickException: If the pattern cannot be loaded.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def get_pattern(pattern_manager: PatternManager, pattern_name: str) -> Pattern:
    """
    Get pattern from the pattern manager.

    Args:
        pattern_manager: Initialized PatternManager instance
        pattern_name: Name of the pattern to load

    Returns:
        Pattern: Loaded pattern object

    Raises:
        click.ClickException: If pattern cannot be loaded
    """
    try:
        return pattern_manager.load_pattern(pattern_name)
    except FileNotFoundError as e:
        raise click.ClickException(
            f"Pattern '{pattern_name}' not found in {pattern_manager.base_path}"
        ) from e
    except Exception as e:
        raise click.ClickException(f"Error loading pattern: {e}") from e
main()

Entry point for TNH-FAB CLI tool.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def main():
    """Entry point for TNH-FAB CLI tool."""
    tnh_fab()
process(config, input_file, pattern, section, paragraph, template)

Apply custom pattern-based processing to text with flexible structuring options.

This command provides flexible text processing using customizable patterns. It can process text by sections (defined in a JSON file or auto-detected), by paragraphs, or as a whole (the default). This is particularly useful for formatting, restructuring, or applying consistent transformations to text.

Examples:


# Process using a specific pattern
$ tnh-fab process -p format_xml input.txt


# Process using paragraph mode
$ tnh-fab process -p format_xml -g input.txt


# Process with custom sections
$ tnh-fab process -p format_xml -s sections.json input.txt


# Process with template values
$ tnh-fab process -p format_xml -t template.yaml input.txt

Processing Modes:


1. Single Input Mode (default)
    - Processes entire input.


2. Section Mode (-s):
    - Uses sections from JSON file if provided (-s)
    - If no section file is provided, sections are auto-generated.
    - Processes each section according to pattern


3. Paragraph Mode (-g):
    - Treats each line/paragraph as a separate unit
    - Useful for simpler processing tasks
    - More memory efficient for large files

Notes:
  • The required pattern must exist in the pattern directory.
  • Template values can customize pattern behavior.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@tnh_fab.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
)
@click.option("-p", "--pattern", required=True, help="Pattern name for processing")
@click.option(
    "-s",
    "--section",
    type=click.Path(exists=True, path_type=Path),
    help="Process using sections from JSON file, or auto-generate if no file provided",
)
@click.option("-g", "--paragraph", is_flag=True, help="Process text by paragraphs")
@click.option(
    "-t",
    "--template",
    type=click.Path(exists=True, path_type=Path),
    help="YAML file containing template values",
)
@pass_config
def process(
    config: TNHFabConfig,
    input_file: Optional[Path],
    pattern: str,
    section: Optional[Path],
    paragraph: bool,
    template: Optional[Path],
):
    """Apply custom pattern-based processing to text with flexible structuring options.

    This command provides flexible text processing using customizable patterns. It can
    process text either by sections (defined in a JSON file or auto-detected), by
    paragraphs, or can be used to process a text as a whole (this is the default).
    This is particularly useful for formatting, restructuring, or applying
    consistent transformations to text.

    Examples:

        \b
        # Process using a specific pattern
        $ tnh-fab process -p format_xml input.txt

        \b
        # Process using paragraph mode
        $ tnh-fab process -p format_xml -g input.txt

        \b
        # Process with custom sections
        $ tnh-fab process -p format_xml -s sections.json input.txt

        \b
        # Process with template values
        $ tnh-fab process -p format_xml -t template.yaml input.txt


    Processing Modes:

        \b
        1. Single Input Mode (default)
            - Processes entire input.

        \b
        2. Section Mode (-s):
            - Uses sections from JSON file if provided (-s)
            - If no section file is provided, sections are auto-generated.
            - Processes each section according to pattern

        \b
        3. Paragraph Mode (-g):
            - Treats each line/paragraph as a separate unit
            - Useful for simpler processing tasks
            - More memory efficient for large files

    \b
    Notes:
        - Required pattern must exist in pattern directory
        - Template values can customize pattern behavior

    """
    text = read_input(click, input_file)  # type: ignore
    process_pattern = get_pattern(config.pattern_manager, pattern)

    template_dict: Dict[str, str] = {}

    if paragraph:
        result = process_text_by_paragraphs(
            text, template_dict, pattern=process_pattern
        )
        for processed in result:
            click.echo(processed)
    elif section is not None:  # Section mode (either file or auto-generate)
        if isinstance(section, Path):  # Section file provided
            sections_json = Path(section).read_text()
            text_obj = TextObject.model_validate_json(sections_json)

        else:  # Auto-generate sections
            default_section_pattern = get_pattern(
                config.pattern_manager, DEFAULT_SECTION_PATTERN
            )
            text_obj = find_sections(text, section_pattern=default_section_pattern)

        result = process_text_by_sections(
            text, text_obj, template_dict, pattern=process_pattern
        )
        for processed_section in result:
            click.echo(processed_section.processed_text)
    else:
        result = process_text(
            text, pattern=process_pattern, template_dict=template_dict
        )
        click.echo(result)
punctuate(config, input_file, language, style, review_count, pattern)

Add punctuation and structure to text based on language-specific rules.

This command processes input text to add or correct punctuation, spacing, and basic structural elements. It is particularly useful for texts that lack proper punctuation or need standardization.

Examples:


# Process a file using default settings
$ tnh-fab punctuate input.txt


# Process Vietnamese text with custom style
$ tnh-fab punctuate -l vi -y "Modern" input.txt


# Process from stdin with increased review passes
$ cat input.txt | tnh-fab punctuate -c 5
Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@tnh_fab.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
)
@click.option(
    "-l",
    "--language",
    help="Source language code (e.g., 'en', 'vi'). Auto-detected if not specified.",
)
@click.option(
    "-y", "--style", default="APA", help="Punctuation style to apply (default: 'APA')"
)
@click.option(
    "-c",
    "--review-count",
    type=int,
    default=3,
    help="Number of review passes (default: 3)",
)
@click.option(
    "-p",
    "--pattern",
    default="default_punctuate",
    help="Pattern name for punctuation rules (default: 'default_punctuate')",
)
@pass_config
def punctuate(
    config: TNHFabConfig,
    input_file: Optional[Path],
    language: Optional[str],
    style: str,
    review_count: int,
    pattern: str,
):
    """Add punctuation and structure to text based on language-specific rules.

    This command processes input text to add or correct punctuation, spacing, and basic
    structural elements. It is particularly useful for texts that lack proper punctuation
    or need standardization.


    Examples:

        \b
        # Process a file using default settings
        $ tnh-fab punctuate input.txt

        \b
        # Process Vietnamese text with custom style
        $ tnh-fab punctuate -l vi -y "Modern" input.txt

        \b
        # Process from stdin with increased review passes
        $ cat input.txt | tnh-fab punctuate -c 5

    """
    text = read_input(click, input_file)  # type: ignore
    punctuate_pattern = get_pattern(config.pattern_manager, pattern)
    result = punctuate_text(
        text,
        source_language=language,
        punctuate_pattern=punctuate_pattern,
        template_dict={"style_convention": style, "review_count": review_count},
    )
    click.echo(result)
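
Roughly the library-level call the command performs. The unpunctuated sample is made up, running it requires OpenAI credentials and a configured pattern directory, and the import location of punctuate_text is an assumption (it is not shown in this file).

from tnh_scholar.cli_tools.tnh_fab.tnh_fab import TNHFabConfig, get_pattern  # inferred import path
from tnh_scholar.ai_text_processing import punctuate_text  # assumed import location

config = TNHFabConfig()
pattern = get_pattern(config.pattern_manager, "default_punctuate")
result = punctuate_text(
    "thay invites the bell breathe in breathe out",  # made-up unpunctuated text
    source_language="en",
    punctuate_pattern=pattern,
    template_dict={"style_convention": "APA", "review_count": 3},
)
print(result)
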
read_input(ctx, input_file)

Read input from file or stdin.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
def read_input(ctx: Context, input_file: Optional[Path]) -> str:
    """Read input from file or stdin."""
    if input_file:
        return input_file.read_text()
    if not sys.stdin.isatty():
        return sys.stdin.read()
    ctx.fail("No input provided")
section(config, input_file, language, num_sections, review_count, pattern)

Analyze and divide text into logical sections based on content.

This command processes the input text to identify coherent sections based on content analysis. It generates a structured representation of the text with sections that maintain logical continuity. Each section includes metadata such as title and line range.

Examples:


# Auto-detect sections in a file
$ tnh-fab section input.txt


# Specify desired number of sections
$ tnh-fab section -n 5 input.txt


# Process Vietnamese text with custom pattern
$ tnh-fab section -l vi -p custom_section_pattern input.txt


# Section text from stdin with increased review
$ cat input.txt | tnh-fab section -c 5

Output Format:
  JSON object containing:
  • language: Detected or specified language code
  • sections: Array of section objects, each with:
    • title: Section title in original language
    • start_line: Starting line number (inclusive)
    • end_line: Ending line number (inclusive)

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@tnh_fab.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
)
@click.option(
    "-l",
    "--language",
    help="Source language code (e.g., 'en', 'vi'). Auto-detected if not specified.",
)
@click.option(
    "-n",
    "--num-sections",
    type=int,
    help="Target number of sections (auto-calculated if not specified)",
)
@click.option(
    "-c",
    "--review-count",
    type=int,
    default=3,
    help="Number of review passes (default: 3)",
)
@click.option(
    "-p",
    "--pattern",
    default="default_section",
    help="Pattern name for section analysis (default: 'default_section')",
)
@pass_config
def section(
    config: TNHFabConfig,
    input_file: Optional[Path],
    language: Optional[str],
    num_sections: Optional[int],
    review_count: int,
    pattern: str,
):
    """Analyze and divide text into logical sections based on content.

    This command processes the input text to identify coherent sections based on content
    analysis. It generates a structured representation of the text with sections that
    maintain logical continuity. Each section includes metadata such as title and line
    range.

    Examples:

        \b
        # Auto-detect sections in a file
        $ tnh-fab section input.txt

        \b
        # Specify desired number of sections
        $ tnh-fab section -n 5 input.txt

        \b
        # Process Vietnamese text with custom pattern
        $ tnh-fab section -l vi -p custom_section_pattern input.txt

        \b
        # Section text from stdin with increased review
        $ cat input.txt | tnh-fab section -c 5

    \b
    Output Format:
        JSON object containing:
        - language: Detected or specified language code
        - sections: Array of section objects, each with:
            - title: Section title in original language
            - start_line: Starting line number (inclusive)
            - end_line: Ending line number (inclusive)
    """
    text = read_input(click, input_file)  # type: ignore
    section_pattern = get_pattern(config.pattern_manager, pattern)
    result = find_sections(
        text,
        source_language=language,
        section_pattern=section_pattern,
        section_count=num_sections,
        review_count=review_count,
    )
    # For prototype, just output the JSON representation
    click.echo(result.model_dump_json(indent=2))
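
Roughly the library-level equivalent of the section command. The three-line sample is made up, running it requires OpenAI credentials and a configured pattern directory, and the import location of find_sections is an assumption (it is not shown in this file).

from tnh_scholar.cli_tools.tnh_fab.tnh_fab import TNHFabConfig, get_pattern  # inferred import path
from tnh_scholar.ai_text_processing import find_sections  # assumed import location

config = TNHFabConfig()
pattern = get_pattern(config.pattern_manager, "default_section")
text_obj = find_sections(
    "Opening verse\nBody of the talk\nClosing dedication",  # made-up sample text
    source_language="en",
    section_pattern=pattern,
    section_count=None,
    review_count=3,
)
print(text_obj.model_dump_json(indent=2))
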
tnh_fab(ctx, verbose, debug, quiet)

TNH-FAB: Thich Nhat Hanh Scholar Text processing command-line tool.

CORE COMMANDS: punctuate, section, translate, process

To get help on any command and see its options:

tnh-fab [COMMAND] --help

Provides specialized processing for multi-lingual Dharma content.

Offers functionalities for punctuation, sectioning, line-based translation, and general text processing based on predefined patterns. Input text can be provided either via a file or standard input.

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@click.group()
@click.option("-v", "--verbose", is_flag=True, help="Enable detailed logging. (NOT implemented)")
@click.option("--debug", is_flag=True, help="Enable debug output")
@click.option("--quiet", is_flag=True, help="Suppress all non-error output")
@click.pass_context
def tnh_fab(ctx: Context, verbose: bool, debug: bool, quiet: bool):
    """TNH-FAB: Thich Nhat Hanh Scholar Text processing command-line tool.

    CORE COMMANDS: punctuate, section, translate, process

    To get help on any command and see its options:

    tnh-fab [COMMAND] --help

    Provides specialized processing for multi-lingual Dharma content.

    Offers functionalities for punctuation, sectioning, line-based translation,
    and general text processing based on predefined patterns.
    Input text can be provided either via a file or standard input.
    """        
    config = ctx.ensure_object(TNHFabConfig)

    if not check_openai_env():

        ctx.fail("Missing OpenAI Credentials.")

    config.verbose = verbose
    config.debug = debug
    config.quiet = quiet

    if not quiet:
        if debug:
            setup_logging(log_level=logging.DEBUG)
        else:
            setup_logging(log_level=logging.INFO)
translate(config, input_file, language, target, style, context_lines, segment_size, pattern)

Translate text while preserving line numbers and contextual understanding.

This command performs intelligent translation that maintains line number correspondence between source and translated text. It uses surrounding context to improve translation accuracy and consistency, particularly important for Buddhist texts where terminology and context are crucial.

Examples:


# Translate Vietnamese text to English
$ tnh-fab translate -l vi input.txt


# Translate to French with specific style
$ tnh-fab translate -l vi -r fr -y "Formal" input.txt


# Translate with increased context
$ tnh-fab translate --context-lines 5 input.txt


# Translate using custom segment size
$ tnh-fab translate --segment-size 10 input.txt

Notes:

- Line numbers are preserved in the output
- Context lines are used to improve translation accuracy
- Segment size affects processing speed and memory usage

Source code in src/tnh_scholar/cli_tools/tnh_fab/tnh_fab.py
@tnh_fab.command()
@click.argument(
    "input_file", type=click.Path(exists=True, path_type=Path), required=False
)
@click.option(
    "-l", "--language", help="Source language code. Auto-detected if not specified."
)
@click.option(
    "-r", "--target", default="en", help="Target language code (default: 'en')"
)
@click.option(
    "-y", "--style", help="Translation style (e.g., 'American Dharma Teaching')"
)
@click.option(
    "--context-lines",
    type=int,
    default=3,
    help="Number of context lines to consider (default: 3)",
)
@click.option(
    "--segment-size",
    type=int,
    help="Lines per translation segment (auto-calculated if not specified)",
)
@click.option(
    "-p",
    "--pattern",
    default="default_line_translation",
    help="Pattern name for translation (default: 'default_line_translation')",
)
@pass_config
def translate(
    config: TNHFabConfig,
    input_file: Optional[Path],
    language: Optional[str],
    target: str,
    style: Optional[str],
    context_lines: int,
    segment_size: Optional[int],
    pattern: str,
):
    """Translate text while preserving line numbers and contextual understanding.

    This command performs intelligent translation that maintains line number correspondence
    between source and translated text. It uses surrounding context to improve translation
    accuracy and consistency, particularly important for Buddhist texts where terminology
    and context are crucial.

    Examples:

        \b
        # Translate Vietnamese text to English
        $ tnh-fab translate -l vi input.txt

        \b
        # Translate to French with specific style
        $ tnh-fab translate -l vi -r fr -y "Formal" input.txt

        \b
        # Translate with increased context
        $ tnh-fab translate --context-lines 5 input.txt

        \b
        # Translate using custom segment size
        $ tnh-fab translate --segment-size 10 input.txt

    \b
    Notes:
        - Line numbers are preserved in the output
        - Context lines are used to improve translation accuracy
        - Segment size affects processing speed and memory usage
    """
    text = read_input(click, input_file)  # type: ignore
    translation_pattern = get_pattern(config.pattern_manager, pattern)
    result = translate_text_by_lines(
        text,
        source_language=language,
        target_language=target,
        pattern=translation_pattern,
        style=style,
        context_lines=context_lines,
        segment_size=segment_size,
    )
    click.echo(result)
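For illustration only, the translate command can also be fed from standard input, which is how the `cat input.txt | tnh-fab ...` examples above work. A minimal sketch, assuming the same module path as above and a working OpenAI environment; the Vietnamese sample line is arbitrary.

from click.testing import CliRunner
from tnh_scholar.cli_tools.tnh_fab.tnh_fab import tnh_fab  # assumed module path

runner = CliRunner()
source_text = "Thầy đã về, thầy đã tới.\n"  # arbitrary sample line
result = runner.invoke(
    tnh_fab,
    ["translate", "-l", "vi", "-r", "en", "--context-lines", "3"],
    input=source_text,  # no INPUT_FILE argument, so text is read from stdin
)
print(result.output)  # translated text with line correspondence preserved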

tnh_setup

tnh_setup
OPENAI_ENV_HELP_MSG = "\n>>>>>>>>>> OpenAI API key not found in environment. <<<<<<<<<\n\nFor AI processing with TNH-scholar:\n\n1. Get an API key from https://platform.openai.com/api-keys\n2. Set the OPENAI_API_KEY environment variable:\n\n export OPENAI_API_KEY='your-api-key-here' # Linux/Mac\n set OPENAI_API_KEY=your-api-key-here # Windows\n\nFor OpenAI API access help: https://platform.openai.com/\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>> -- <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n" module-attribute
PATTERNS_URL = 'https://github.com/aaronksolomon/patterns/archive/main.zip' module-attribute
create_config_dirs()

Create required configuration directories.

Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py
def create_config_dirs():
    """Create required configuration directories."""
    TNH_CONFIG_DIR.mkdir(parents=True, exist_ok=True)
    TNH_LOG_DIR.mkdir(exist_ok=True)
    TNH_DEFAULT_PATTERN_DIR.mkdir(exist_ok=True)
download_patterns()

Download and extract pattern files from GitHub.

Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py
def download_patterns() -> bool:
    """Download and extract pattern files from GitHub."""
    try:
        response = requests.get(PATTERNS_URL)
        response.raise_for_status()

        with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
            root_dir = zip_ref.filelist[0].filename.split('/')[0]

            for zip_info in zip_ref.filelist:
                if zip_info.filename.endswith('.md'):
                    rel_path = Path(zip_info.filename).relative_to(root_dir)
                    target_path = TNH_DEFAULT_PATTERN_DIR / rel_path

                    target_path.parent.mkdir(parents=True, exist_ok=True)

                    with zip_ref.open(zip_info) as source, open(target_path, 'wb') as target:
                        target.write(source.read())
        return True

    except Exception as e:
        click.echo(f"Pattern download failed: {e}", err=True)
        return False
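The two helpers above can also be called directly, for example from a provisioning script. A minimal sketch, assuming the module path from the "Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py" note and network access to GitHub.

from tnh_scholar.cli_tools.tnh_setup.tnh_setup import (  # assumed module path
    create_config_dirs,
    download_patterns,
)

create_config_dirs()  # ensure ~/.config/tnh-scholar plus its logs/ and patterns/ subdirectories
if download_patterns():  # fetch the *.md pattern files from the GitHub archive
    print("Pattern files installed.")
else:
    print("Pattern download failed; see the error printed to stderr.")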
main()

Entry point for setup CLI tool.

Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py
def main():
    """Entry point for setup CLI tool."""
    tnh_setup()
tnh_setup(skip_env, skip_patterns)

Set up TNH Scholar configuration.

Source code in src/tnh_scholar/cli_tools/tnh_setup/tnh_setup.py
@click.command()
@click.option('--skip-env', is_flag=True, help='Skip API key setup')
@click.option('--skip-patterns', is_flag=True, help='Skip pattern download')
def tnh_setup(skip_env: bool, skip_patterns: bool):
    """Set up TNH Scholar configuration."""
    click.echo("Setting up TNH Scholar...")

    # Create config directories
    create_config_dirs()
    click.echo(f"Created config directory: {TNH_CONFIG_DIR}")

    # Pattern download
    if not skip_patterns and click.confirm(
                "\nDownload pattern (markdown text) files from GitHub?\n"
                f"Source: {PATTERNS_URL}\n"
                f"Target: {TNH_DEFAULT_PATTERN_DIR}"
            ):
        if download_patterns():
            click.echo("Pattern files downloaded successfully")
        else:
            click.echo("Pattern download failed", err=True)

    # Environment test:
    if not skip_env:
        load_dotenv()  # for development
        if not check_openai_env(output=False):
            print(OPENAI_ENV_HELP_MSG)

token_count

token_count
main()

Entry point for the token-count CLI tool.

Source code in src/tnh_scholar/cli_tools/token_count/token_count.py
def main():
    """Entry point for the token-count CLI tool."""
    token_count_cli()
token_count_cli(input_file)

Return the OpenAI API token count of a text file, based on gpt-4o.

Source code in src/tnh_scholar/cli_tools/token_count/token_count.py
@click.command()
@click.argument("input_file", type=click.File("r"), default="-")
def token_count_cli(input_file):
    """Return the Open AI API token count of a text file. Based on gpt-4o."""
    text = input_file.read()
    result = token_count(text)
    click.echo(result)
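Since the INPUT_FILE argument defaults to '-', the command reads standard input when no file is given. A minimal sketch of in-process use with Click's test runner; the import path is assumed from the source note above.

from click.testing import CliRunner
from tnh_scholar.cli_tools.token_count.token_count import token_count_cli  # assumed module path

runner = CliRunner()
result = runner.invoke(token_count_cli, input="Breathing in, I calm my body.\n")
print(result.output.strip())  # the gpt-4o token count, printed as an integer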

ytt_fetch

ytt_fetch

Simple CLI tool for retrieving video transcripts.

This module provides a command line interface for downloading video transcripts in specified languages. It uses yt-dlp for video info extraction.

main()
Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
def main():
    ytt_fetch()
ytt_fetch(url, lang, output)

YouTube Transcript Fetch: Retrieve and save the transcript for a YouTube video using yt-dlp.

Source code in src/tnh_scholar/cli_tools/ytt_fetch/ytt_fetch.py
@click.command()
@click.argument("url")
@click.option(
    "-l", "--lang", default="en", help="Language code for transcript (default: en)"
)
@click.option(
    "-o",
    "--output",
    type=click.Path(),
    help="Save transcript text to file instead of printing",
)
def ytt_fetch(url: str, lang: str, output: Optional[str]) -> None:
    """
    YouTube Transcript Fetch: Retrieve and save the transcript for a YouTube video using yt-dlp.
    """

    try:
        transcript_text = get_transcript(url, lang)

    except TranscriptNotFoundError as e:
        click.echo(e, err=True)
        sys.exit(1)
    except yt_dlp.utils.DownloadError as e:
        click.echo(f"Failed to extract video transcript: {e}", err=True)
        sys.exit(1)

    try:
        if output:
            output_path = Path(output)
            write_text_to_file(output_path, transcript_text, overwrite=True)
            click.echo(f"Transcript written to: {output_path}")
        else:
            click.echo(transcript_text)

    except FileNotFoundError as e:
        click.echo(f"File not found error: {e}", err=True)
        sys.exit(1)
    except (IOError, OSError) as e:
        click.echo(f"Error writing transcript to file: {e}", err=True)
        sys.exit(1)
    except TypeError as e:
        click.echo(f"Unexpected type error: {e}", err=True)
        sys.exit(1)
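A usage sketch, not part of the module source: the URL is a placeholder, the import path is assumed from the source note above, and the call needs network access plus an available transcript in the requested language.

from click.testing import CliRunner
from tnh_scholar.cli_tools.ytt_fetch.ytt_fetch import ytt_fetch  # assumed module path

runner = CliRunner()
result = runner.invoke(
    ytt_fetch,
    ["https://www.youtube.com/watch?v=VIDEO_ID",  # placeholder URL
     "-l", "en",
     "-o", "transcript.txt"],  # omit -o to print the transcript instead of saving it
)
print(result.output)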

dev_tools

generate_tree

ignore_list = ['__pycache__', '*.pyc', '*.pyo', '*.pyd', '.git*', '.pytest_cache', '*.egg-info', 'dist', 'build', 'data', 'processed_data', 'sandbox', 'patterns', '.vscode', 'tmp', 'site'] module-attribute
ignore_str = '|'.join(ignore_list) module-attribute
output_file = sys.argv[2] if len(sys.argv) > 2 else 'project_directory_tree.txt' module-attribute
root_dir = sys.argv[1] if len(sys.argv) > 1 else '.' module-attribute
generate_tree(root_dir='.', output_file='project_directory_tree.txt')
Source code in src/tnh_scholar/dev_tools/generate_tree.py
def generate_tree(root_dir='.', output_file='project_directory_tree.txt'):
    root = Path(root_dir)
    output = Path(output_file)

    try:
        # Try using tree command first
        subprocess.run(["tree", "-I", 
            ignore_str,
            "-o", str(output)], check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Fallback to pathlib implementation
        with open(output, "w") as f:
            f.write(".\n")
            for path in sorted(root.rglob("*")):
                if any(p.startswith(".") for p in path.parts):
                    continue
                if path.name in {"__pycache__", "*.pyc", "*.pyo", "*.pyd"}:
                    continue
                rel_path = path.relative_to(root)
                f.write(f"{'    ' * (len(path.parts)-1)}├── {path.name}\n")
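A minimal usage sketch, assuming the module path from the source note above: the function prefers the external `tree` command and falls back to the pathlib walk when `tree` is not installed.

from tnh_scholar.dev_tools.generate_tree import generate_tree  # assumed module path

# Write a listing of the current project to project_directory_tree.txt,
# skipping the directories named in ignore_list when `tree` is available.
generate_tree(root_dir=".", output_file="project_directory_tree.txt")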

journal_processing

journal_process

BATCH_RETRY_DELAY = 5 module-attribute
MAX_BATCH_RETRIES = 40 module-attribute
MAX_TOKEN_LIMIT = 60000 module-attribute
journal_schema = {'type': 'object', 'properties': {'journal_summary': {'type': 'string'}, 'sections': {'type': 'array', 'items': {'type': 'object', 'properties': {'title_vi': {'type': 'string'}, 'title_en': {'type': 'string'}, 'author': {'type': ['string', 'null']}, 'summary': {'type': 'string'}, 'keywords': {'type': 'array', 'items': {'type': 'string'}}, 'start_page': {'type': 'integer', 'minimum': 1}, 'end_page': {'type': 'integer', 'minimum': 1}}, 'required': ['title_vi', 'title_en', 'summary', 'keywords', 'start_page', 'end_page']}}}, 'required': ['journal_summary', 'sections']} module-attribute
logger = logging.getLogger('journal_process') module-attribute
batch_section(input_xml_path, batch_jsonl, system_message, journal_name)

Splits the journal content into sections using GPT, with retries for both starting and completing the batch.

Parameters:

- input_xml_path (Path): Path to the input XML file. Required.
- batch_jsonl (Path): Path for the JSONL batch file to create. Required.
- system_message (str): System message used for the sectioning prompt. Required.
- journal_name (str): Name of the journal being processed. Required.

Returns:

- str: The result of the batch sectioning process as a serialized JSON object.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def batch_section(
    input_xml_path: Path, batch_jsonl: Path, system_message, journal_name
):
    """
    Splits the journal content into sections using GPT, with retries for both starting and completing the batch.

    Args:
        input_xml_path (Path): Path to the input XML file.
        batch_jsonl (Path): Path for the JSONL batch file to create.
        system_message (str): System message used for the sectioning prompt.
        journal_name (str): Name of the journal being processed.

    Returns:
        str: the result of the batch sectioning process as a serialized json object.
    """
    try:
        logger.info(
            f"Starting sectioning batch for {journal_name} with file:\n\t{input_xml_path}"
        )
        # Load journal content
        journal_pages = get_text_from_file(input_xml_path)

        # Create GPT messages for sectioning
        user_message_wrapper = lambda text: f"{text}"
        messages = generate_messages(
            system_message, user_message_wrapper, [journal_pages]
        )

        # Create JSONL file for batch processing
        jsonl_file = create_jsonl_file_for_batch(messages, batch_jsonl, json_mode=True)

    except Exception as e:
        logger.error(
            f"Failed to initialize batch sectioning data for journal '{journal_name}'.",
            extra={"input_xml_path": input_xml_path},
            exc_info=True,
        )
        raise RuntimeError(
            f"Error initializing batch sectioning data for journal '{journal_name}'."
        ) from e

    response = start_batch_with_retries(
        jsonl_file,
        description=f"Batch for sectioning journal: {journal_name} | input file: {input_xml_path}",
    )

    if response:
        json_result = response[
            0
        ]  # should return json, just one batch so first response
        # Log success and return output json
        logger.info(
            f"Successfully batch sectioned journal '{journal_name}' with input file: {input_xml_path}."
        )
        return json_result
    else:
        logger.error("Section batch failed to get response.")
        return ""
batch_translate(input_xml_path, batch_json_path, metadata_path, system_message, journal_name)

Translates the journal sections using the GPT model. Saves the translated content back to XML.

Parameters:

- input_xml_path (Path): Path to the input XML file. Required.
- batch_json_path (Path): Path for the JSONL batch file to create. Required.
- metadata_path (Path): Path to the metadata JSON file. Required.
- system_message (str): System message used for the translation prompt. Required.
- journal_name (str): Name of the journal being processed. Required.

Returns:

- List: Translated data returned from the batch or immediate process.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def batch_translate(
    input_xml_path: Path,
    batch_json_path: Path,
    metadata_path: Path,
    system_message,
    journal_name: str,
):
    """
    Translates the journal sections using the GPT model.
    Saves the translated content back to XML.

    Args:
        input_xml_path (Path): Path to the input XML file.
        batch_json_path (Path): Path for the JSONL batch file to create.
        metadata_path (Path): Path to the metadata JSON file.
        system_message (str): System message used for the translation prompt.
        journal_name (str): Name of the journal being processed.

    Returns:
        List: Translated data returned from the batch or immediate process.
    """
    logger.info(
        f"Starting translation batch for journal '{journal_name}':\n\twith file: {input_xml_path}\n\tmetadata: {metadata_path}"
    )

    # Data initialization:
    try:
        # load metadata
        serial_json = get_text_from_file(metadata_path)

        section_metadata = deserialize_json(serial_json)
        if not section_metadata:
            raise RuntimeError(f"Metadata could not be loaded from {metadata_path}.")

        # Extract page groups and split XML content
        page_groups = extract_page_groups_from_metadata(section_metadata)
        xml_content = get_text_from_file(input_xml_path)
        section_contents = split_xml_on_pagebreaks(xml_content, page_groups)

        if section_contents:
            logger.debug(f"section_contents[0]:\n{section_contents[0]}")
        else:
            logger.error("No sectin contents.")

    except Exception as e:
        logger.error(
            f"Failed to initialize data for translation batching for journal '{journal_name}'.",
            exc_info=True,
        )
        raise RuntimeError(
            f"Error during data initialization for journal '{journal_name}'."
        ) from e

    translation_data = translate_sections(
        batch_json_path,
        system_message,
        section_contents,
        section_metadata,
        journal_name,
    )
    return translation_data
deserialize_json(serialized_data)

Converts a serialized JSON string into a Python dictionary.

Parameters:

Name Type Description Default
serialized_data str

The JSON string to deserialize.

required

Returns:

Name Type Description
dict

The deserialized Python dictionary.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def deserialize_json(serialized_data: str):
    """
    Converts a serialized JSON string into a Python dictionary.

    Args:
        serialized_data (str): The JSON string to deserialize.

    Returns:
        dict: The deserialized Python dictionary.
    """
    if not isinstance(serialized_data, str):
        logger.error(
            f"String input required for deserialize_json. Received: {type(serialized_data)}"
        )
        raise ValueError("String input required.")

    try:
        # Convert the JSON string into a dictionary
        return json.loads(serialized_data)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to deserialize JSON: {e}")
        raise
extract_page_groups_from_metadata(metadata)

Extracts page groups from the section metadata for use with split_xml_pages.

Parameters:

Name Type Description Default
metadata dict

The section metadata containing sections with start and end pages.

required

Returns:

Type Description

List[Tuple[int, int]]: A list of tuples, each representing a page range (start_page, end_page).

Source code in src/tnh_scholar/journal_processing/journal_process.py
def extract_page_groups_from_metadata(metadata):
    """
    Extracts page groups from the section metadata for use with `split_xml_pages`.

    Parameters:
        metadata (dict): The section metadata containing sections with start and end pages.

    Returns:
        List[Tuple[int, int]]: A list of tuples, each representing a page range (start_page, end_page).
    """
    page_groups = []

    # Ensure metadata contains sections
    if "sections" not in metadata or not isinstance(metadata["sections"], list):
        raise ValueError(
            "Metadata does not contain a valid 'sections' key with a list of sections."
        )

    for section in metadata["sections"]:
        try:
            # Extract start and end pages
            start_page = section.get("start_page")
            end_page = section.get("end_page")

            # Ensure both start_page and end_page are integers
            if not isinstance(start_page, int) or not isinstance(end_page, int):
                raise ValueError(f"Invalid page range in section: {section}")

            # Add the tuple to the page groups list
            page_groups.append((start_page, end_page))

        except KeyError as e:
            print(f"Missing key in section metadata: {e}")
        except ValueError as e:
            print(f"Error processing section metadata: {e}")

    logger.debug(f"page groups found: {page_groups}")

    return page_groups
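A small worked example, using illustrative metadata shaped like the journal_schema shown earlier in this module; the import path is assumed from the source note above.

from tnh_scholar.journal_processing.journal_process import (  # assumed module path
    extract_page_groups_from_metadata,
)

metadata = {
    "journal_summary": "Illustrative issue summary",
    "sections": [
        {"title_vi": "Lời ngỏ", "title_en": "Opening Words", "summary": "...",
         "keywords": ["mindfulness"], "start_page": 1, "end_page": 3},
        {"title_vi": "Bài giảng", "title_en": "Dharma Talk", "summary": "...",
         "keywords": ["practice"], "start_page": 4, "end_page": 12},
    ],
}
print(extract_page_groups_from_metadata(metadata))  # [(1, 3), (4, 12)]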
generate_all_batches(processed_document_dir, system_message, user_wrap_function, file_regex='.*\\.xml')

Generate cleaning batches for all journals in the specified directory.

Parameters:

Name Type Description Default
processed_document_dir str

Path to the directory containing processed journal data.

required
system_message str

System message template for batch processing.

required
user_wrap_function callable

Function to wrap user input for processing pages.

required
file_regex str

Regex pattern to identify target files (default: ".*.xml").

'.*\\.xml'

Returns:

Type Description

None

Source code in src/tnh_scholar/journal_processing/journal_process.py
def generate_all_batches(
    processed_document_dir: str,
    system_message: str,
    user_wrap_function,
    file_regex: str = r".*\.xml",
):
    """
    Generate cleaning batches for all journals in the specified directory.

    Parameters:
        processed_document_dir (str): Path to the directory containing processed journal data.
        system_message (str): System message template for batch processing.
        user_wrap_function (callable): Function to wrap user input for processing pages.
        file_regex (str): Regex pattern to identify target files (default: ".*\\.xml").

    Returns:
        None
    """
    logger = logging.getLogger(__name__)
    document_dir = Path(processed_document_dir)
    regex = re.compile(file_regex)

    for journal_file in document_dir.iterdir():
        if journal_file.is_file() and regex.search(journal_file.name):
            try:
                # Derive output file path
                output_file = journal_file.with_suffix(".jsonl")
                logger.info(f"Generating batch for {journal_file}...")

                # Call single batch function
                generate_single_oa_batch_from_pages(
                    input_xml_file=str(journal_file),
                    output_file=str(output_file),
                    system_message=system_message,
                    user_wrap_function=user_wrap_function,
                )
            except Exception as e:
                logger.error(f"Failed to process {journal_file}: {e}")
                continue

    logger.info("Batch generation completed.")
generate_clean_batch(input_xml_file, output_file, system_message, user_wrap_function)

Generate a batch file for the OpenAI (OA) API using a single input XML file.

Parameters:

Name Type Description Default
input_xml_file str

Full path to the input XML file to process.

required
output_file str

Full path to the output batch JSONL file.

required
system_message str

System message template for batch processing.

required
user_wrap_function callable

Function to wrap user input for processing pages.

required

Returns:

Name Type Description
str

Path to the created batch file.

Raises:

Type Description
Exception

If an error occurs during file processing.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def generate_clean_batch(
    input_xml_file: str, output_file: str, system_message: str, user_wrap_function
):
    """
    Generate a batch file for the OpenAI (OA) API using a single input XML file.

    Parameters:
        input_xml_file (str): Full path to the input XML file to process.
        output_file (str): Full path to the output batch JSONL file.
        system_message (str): System message template for batch processing.
        user_wrap_function (callable): Function to wrap user input for processing pages.

    Returns:
        str: Path to the created batch file.

    Raises:
        Exception: If an error occurs during file processing.
    """

    try:
        # Read the OCR text from the batch file
        text = get_text_from_file(input_xml_file)
        logger.info(f"Processing file: {input_xml_file}")

        # Split the text into pages for processing
        pages = split_xml_on_pagebreaks(text)
        pages = wrap_all_lines(pages)  # wrap lines with brackets.
        if not pages:
            raise ValueError(f"No pages found in XML file: {input_xml_file}")
        logger.info(f"Found {len(pages)} pages in {input_xml_file}.")

        max_tokens = [_get_max_tokens_for_clean(page) for page in pages]

        # Generate messages for the pages
        batch_message_seq = generate_messages(system_message, user_wrap_function, pages)

        # Save the batch file
        create_jsonl_file_for_batch(
            batch_message_seq, output_file, max_token_list=max_tokens
        )
        logger.info(f"Batch file created successfully: {output_file}")

        return output_file

    except FileNotFoundError:
        logger.error("File not found.")
        raise
    except ValueError as e:
        logger.error(f"Value error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error while processing {input_xml_file}: {e}")
        raise
generate_single_oa_batch_from_pages(input_xml_file, output_file, system_message, user_wrap_function)

*** Deprecated *** Generate a batch file for the OpenAI (OA) API using a single input XML file.

Parameters:

Name Type Description Default
input_xml_file str

Full path to the input XML file to process.

required
output_file str

Full path to the output batch JSONL file.

required
system_message str

System message template for batch processing.

required
user_wrap_function callable

Function to wrap user input for processing pages.

required

Returns:

Name Type Description
str

Path to the created batch file.

Raises:

Type Description
Exception

If an error occurs during file processing.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def generate_single_oa_batch_from_pages(
    input_xml_file: str,
    output_file: str,
    system_message: str,
    user_wrap_function,
):
    """
    *** Deprecated ***
    Generate a batch file for the OpenAI (OA) API using a single input XML file.

    Parameters:
        input_xml_file (str): Full path to the input XML file to process.
        output_file (str): Full path to the output batch JSONL file.
        system_message (str): System message template for batch processing.
        user_wrap_function (callable): Function to wrap user input for processing pages.

    Returns:
        str: Path to the created batch file.

    Raises:
        Exception: If an error occurs during file processing.
    """
    logger = logging.getLogger(__name__)

    try:
        # Read the OCR text from the batch file
        text = get_text_from_file(input_xml_file)
        logger.info(f"Processing file: {input_xml_file}")

        # Split the text into pages for processing
        pages = split_xml_pages(text)
        if not pages:
            raise ValueError(f"No pages found in XML file: {input_xml_file}")
        logger.info(f"Found {len(pages)} pages in {input_xml_file}.")

        # Generate messages for the pages
        batch_message_seq = generate_messages(system_message, user_wrap_function, pages)

        # Save the batch file
        create_jsonl_file_for_batch(batch_message_seq, output_file)
        logger.info(f"Batch file created successfully: {output_file}")

        return output_file

    except FileNotFoundError:
        logger.error(f"File not found: {input_xml_file}")
        raise
    except ValueError as e:
        logger.error(f"Value error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error while processing {input_xml_file}: {e}")
        raise
save_cleaned_data(cleaned_xml_path, cleaned_wrapped_pages, journal_name)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def save_cleaned_data(
    cleaned_xml_path: Path, cleaned_wrapped_pages: List[str], journal_name
):
    try:
        logger.info(f"Saving cleaned content to XML for journal '{journal_name}'.")
        cleaned_wrapped_pages = unwrap_all_lines(cleaned_wrapped_pages)
        save_pages_to_xml(cleaned_xml_path, cleaned_wrapped_pages, overwrite=True)
        logger.info(f"Cleaned journal saved successfully to:\n\t{cleaned_xml_path}")
    except Exception as e:
        logger.error(
            f"Failed to save cleaned data for journal '{journal_name}'.",
            extra={"cleaned_xml_path": cleaned_xml_path},
            exc_info=True,
        )
        raise RuntimeError(
            f"Failed to save cleaned data for journal '{journal_name}'."
        ) from e
save_sectioning_data(output_json_path, raw_output_path, serial_json, journal_name)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def save_sectioning_data(
    output_json_path: Path, raw_output_path: Path, serial_json: str, journal_name
):
    try:
        raw_output_path.write_text(serial_json, encoding="utf-8")
    except Exception as e:
        logger.error(
            f"Failed to write raw response file for journal '{journal_name}'.",
            extra={"raw_output_path": raw_output_path},
            exc_info=True,
        )
        raise RuntimeError(
            f"Failed to write raw response file for journal '{journal_name}'."
        ) from e

    # Validate and save metadata
    try:
        valid = validate_and_save_metadata(
            output_json_path, serial_json, journal_schema
        )
        if not valid:
            raise RuntimeError(
                f"Validation failed for metadata of journal '{journal_name}'."
            )
    except Exception as e:
        logger.error(
            f"Error occurred while validating and saving metadata for journal '{journal_name}'.",
            extra={"output_json_path": output_json_path},
            exc_info=True,
        )
        raise RuntimeError(f"Validation error for journal '{journal_name}'.") from e

    return output_json_path
save_translation_data(xml_output_path, translation_data, journal_name)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def save_translation_data(xml_output_path: Path, translation_data, journal_name):
    # Save translated content back to XML
    try:
        logger.info(f"Saving translated content to XML for journal '{journal_name}'.")
        join_xml_data_to_doc(xml_output_path, translation_data, overwrite=True)
        logger.info(f"Translated journal saved successfully to:\n\t{xml_output_path}")

    except Exception as e:
        logger.error(
            f"Failed to save translation data for journal '{journal_name}'.",
            extra={"xml_output_path": xml_output_path},
            exc_info=True,
        )
        raise RuntimeError(
            f"Failed to save translation data for journal '{journal_name}'."
        ) from e
send_data_for_tx_batch(batch_jsonl_path, section_data_to_send, system_message, max_token_list, journal_name, immediate=False)

Sends data for translation batch or immediate processing.

Parameters:

Name Type Description Default
batch_jsonl_path Path

Path for the JSONL file to save batch data.

required
section_data_to_send List

List of section data to translate.

required
system_message str

System message for the translation process.

required
max_token_list List

List of max tokens for each section.

required
journal_name str

Name of the journal being processed.

required
immediate bool

If True, run immediate chat processing instead of batch.

False

Returns:

Name Type Description
List

Translated data from the batch or immediate process.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def send_data_for_tx_batch(
    batch_jsonl_path: Path,
    section_data_to_send: List,
    system_message,
    max_token_list: List,
    journal_name,
    immediate=False,
):
    """
    Sends data for translation batch or immediate processing.

    Args:
        batch_jsonl_path (Path): Path for the JSONL file to save batch data.
        section_data_to_send (List): List of section data to translate.
        system_message (str): System message for the translation process.
        max_token_list (List): List of max tokens for each section.
        journal_name (str): Name of the journal being processed.
        immediate (bool): If True, run immediate chat processing instead of batch.

    Returns:
        List: Translated data from the batch or immediate process.
    """
    try:
        # Generate all messages using the generate_messages function
        user_message_wrapper = (
            lambda section_info: f"Translate this section with title '{section_info.title}':\n{section_info.content}"
        )
        messages = generate_messages(
            system_message, user_message_wrapper, section_data_to_send
        )

        if immediate:
            logger.info(f"Running immediate chat process for journal '{journal_name}'.")
            translated_data = []
            for i, message in enumerate(messages):
                max_tokens = max_token_list[i]
                response = run_immediate_chat_process(message, max_tokens=max_tokens)
                translated_data.append(response)
            logger.info(
                f"Immediate translation completed for journal '{journal_name}'."
            )
            return translated_data
        else:
            logger.info(f"Running batch processing for journal '{journal_name}'.")
            # Create batch file for batch processing
            jsonl_file = create_jsonl_file_for_batch(
                messages, batch_jsonl_path, max_token_list=max_token_list
            )
            if not jsonl_file:
                raise RuntimeError("Failed to create JSONL file for translation batch.")

            # Process batch and return the result
            translation_data = start_batch_with_retries(
                jsonl_file,
                description=f"Batch for translating journal '{journal_name}'",
            )
            logger.info(f"Batch translation completed for journal '{journal_name}'.")
            return translation_data

    except Exception as e:
        logger.error(
            f"Error during translation processing for journal '{journal_name}'.",
            exc_info=True,
        )
        raise RuntimeError("Error in translation process.") from e
setup_logger(log_file_path)

Configures the logger to write to a log file and the console. Adds a custom "PRIORITY_INFO" logging level for important messages.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def setup_logger(log_file_path):
    """
    Configures the logger to write to a log file and the console.
    Adds a custom "PRIORITY_INFO" logging level for important messages.
    """
    # Remove existing handlers
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # Include logger name
        handlers=[
            logging.FileHandler(log_file_path, encoding="utf-8"),
            logging.StreamHandler(),  # Optional: to log to the console as well
        ],
    )

    # Suppress DEBUG/INFO logs for specific noisy modules
    modules_to_suppress = ["httpx", "httpcore", "urllib3", "openai", "google"]
    for module in modules_to_suppress:
        logger = logging.getLogger(module)
        logger.setLevel(logging.WARNING)  # Suppress DEBUG and INFO logs

    # Add a custom "PRIORITY_INFO" level
    PRIORITY_INFO_LEVEL = 25  # Between INFO (20) and WARNING (30)
    logging.addLevelName(PRIORITY_INFO_LEVEL, "PRIORITY_INFO")

    def priority_info(self, message, *args, **kwargs):
        if self.isEnabledFor(PRIORITY_INFO_LEVEL):
            self._log(PRIORITY_INFO_LEVEL, f"\033[93m{message}\033[0m", args, **kwargs)

    logging.Logger.priority_info = priority_info

    return logging.getLogger(__name__)
translate_sections(batch_jsonl_path, system_message, section_contents, section_metadata, journal_name, immediate=False)

build up sections in batches to translate

Source code in src/tnh_scholar/journal_processing/journal_process.py
def translate_sections(
    batch_jsonl_path: Path,
    system_message,
    section_contents,
    section_metadata,
    journal_name,
    immediate=False,
):
    """build up sections in batches to translate"""

    section_mdata = section_metadata["sections"]
    if len(section_contents) != len(section_mdata):
        raise RuntimeError("Section length mismatch.")

    # collate metadata and section content, calculate max_tokens per section:
    section_data_to_send = []
    max_token_list = []
    current_token_count = 0
    collected_translations = []
    section_last_index = len(section_mdata) - 1

    for i, section_info in enumerate(section_mdata):
        section_content = section_contents[i]
        max_tokens = floor(token_count(section_content) * 1.3) + 1000
        max_token_list.append(max_tokens)
        current_token_count += max_tokens
        section_data = SimpleNamespace(
            title=section_info["title_en"], content=section_content
        )
        section_data_to_send.append(section_data)
        logger.debug(f"section {i}: {section_data.title} added for batch processing.")

        if current_token_count >= MAX_TOKEN_LIMIT or i == section_last_index:
            # send sections for batch processing since token limit reached.
            batch_result = send_data_for_tx_batch(
                batch_jsonl_path,
                section_data_to_send,
                system_message,
                max_token_list,
                journal_name,
                immediate,
            )
            collected_translations.extend(batch_result)

            # reset containers to start building up next batch.
            section_data_to_send = []
            max_token_list = []
            current_token_count = 0

    return collected_translations
unwrap_all_lines(pages)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def unwrap_all_lines(pages):
    result = []
    for page in pages:
        if page == "blank page":
            result.append(page)
        else:
            result.append(unwrap_lines(page))
    return result
unwrap_lines(text)
Removes angle brackets (< >) from encapsulated lines and merges them into
a newline-separated string.

Parameters:
    text (str): The input string with encapsulated lines.

Returns:
    str: A newline-separated string with the encapsulation removed.

Example:
    >>> unwrap_lines("<Line 1> <Line 2> <Line 3>")
    'Line 1\nLine 2\nLine 3'
    >>> unwrap_lines("<Line 1>\n<Line 2>\n<Line 3>")
    'Line 1\nLine 2\nLine 3'

Source code in src/tnh_scholar/journal_processing/journal_process.py
def unwrap_lines(text: str) -> str:
    """
    Removes angle brackets (< >) from encapsulated lines and merges them into
    a newline-separated string.

    Parameters:
        text (str): The input string with encapsulated lines.

    Returns:
        str: A newline-separated string with the encapsulation removed.

    Example:
        >>> unwrap_lines("<Line 1> <Line 2> <Line 3>")
        'Line 1\nLine 2\nLine 3'
        >>> unwrap_lines("<Line 1>\n<Line 2>\n<Line 3>")
        'Line 1\nLine 2\nLine 3'
    """
    # Find all content between < and > using regex
    matches = re.findall(r"<(.*?)>", text)
    # Join the extracted content with newlines
    return "\n".join(matches)
validate_and_clean_data(data, schema)

Recursively validate and clean AI-generated data to fit the given schema. Any missing fields are filled with defaults, and extra fields are ignored.

Parameters:

Name Type Description Default
data dict

The AI-generated data to validate and clean.

required
schema dict

The schema defining the required structure.

required

Returns:

Name Type Description
dict

The cleaned data adhering to the schema.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def validate_and_clean_data(data, schema):
    """
    Recursively validate and clean AI-generated data to fit the given schema.
    Any missing fields are filled with defaults, and extra fields are ignored.

    Args:
        data (dict): The AI-generated data to validate and clean.
        schema (dict): The schema defining the required structure.

    Returns:
        dict: The cleaned data adhering to the schema.
    """

    def clean_value(value, field_schema):
        """
        Clean a single value based on its schema, attempting type conversions where necessary.
        """
        field_type = field_schema["type"]

        # Handle type: string
        if field_type == "string":
            if isinstance(value, str):
                return value
            elif value is not None:
                return str(value)
            return "unset"

        # Handle type: integer
        elif field_type == "integer":
            if isinstance(value, int):
                return value
            elif isinstance(value, str) and value.isdigit():
                return int(value)
            try:
                return int(float(value))  # Handle cases like "2.0"
            except (ValueError, TypeError):
                return 0

        # Handle type: array
        elif field_type == "array":
            if isinstance(value, list):
                item_schema = field_schema.get("items", {})
                return [clean_value(item, item_schema) for item in value]
            elif isinstance(value, str):
                # Try splitting comma-separated strings into a list
                return [v.strip() for v in value.split(",")]
            return []

        # Handle type: object
        elif field_type == "object":
            if isinstance(value, dict):
                return validate_and_clean_data(value, field_schema)
            return {}

        # Handle nullable strings
        elif field_type == ["string", "null"]:
            if value is None or isinstance(value, str):
                return value
            return str(value)

        # Default case for unknown or unsupported types
        return "unset"

    def clean_object(obj, obj_schema):
        """
        Clean a dictionary object based on its schema.
        """
        if not isinstance(obj, dict):
            print(
                f"Expected dict but got: \n{type(obj)}: {obj}\nResetting to empty dict."
            )
            return {}
        cleaned = {}
        properties = obj_schema.get("properties", {})
        for key, field_schema in properties.items():
            # Set default value for missing fields
            cleaned[key] = clean_value(obj.get(key), field_schema)
        return cleaned

    # Handle the top-level object
    if schema["type"] == "object":
        cleaned_data = clean_object(data, schema)
        return cleaned_data
    else:
        raise ValueError("Top-level schema must be of type 'object'.")
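A worked example of the cleaning behaviour, using the module's journal_schema and a deliberately messy record; the import path is assumed as above. Comma-separated keyword strings are split into lists, numeric strings are coerced to integers, and fields outside the schema are dropped.

from tnh_scholar.journal_processing.journal_process import (  # assumed module path
    journal_schema,
    validate_and_clean_data,
)

raw = {
    "journal_summary": "An example issue summary",
    "sections": [
        {"title_vi": "Bài 1", "title_en": "Article 1", "summary": "...",
         "keywords": "zen, breathing",    # string instead of a list
         "start_page": "2.0",             # numeric string instead of an int
         "end_page": 5,
         "unexpected_field": "ignored"},  # not in the schema, so dropped
    ],
}
cleaned = validate_and_clean_data(raw, journal_schema)
print(cleaned["sections"][0]["keywords"])    # ['zen', 'breathing']
print(cleaned["sections"][0]["start_page"])  # 2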
validate_and_save_metadata(output_file_path, json_metadata_serial, schema)

Validates and cleans journal data against the schema, then writes it to a JSON file.

Parameters:

Name Type Description Default
json_metadata_serial str

The journal data as a serialized JSON string to validate and clean.

required
schema dict

The schema defining the required structure.

required
output_file_path str

Path to the output JSON file.

required

Returns:

Name Type Description
bool

True if successfully written to the file, False otherwise.

Source code in src/tnh_scholar/journal_processing/journal_process.py
def validate_and_save_metadata(
    output_file_path: Path, json_metadata_serial: str, schema
):
    """
    Validates and cleans journal data against the schema, then writes it to a JSON file.

    Args:
        json_metadata_serial (str): The journal data as a serialized JSON string to validate and clean.
        schema (dict): The schema defining the required structure.
        output_file_path (str): Path to the output JSON file.

    Returns:
        bool: True if successfully written to the file, False otherwise.
    """
    try:
        # Clean the data to fit the schema
        data = deserialize_json(json_metadata_serial)
        cleaned_data = validate_and_clean_data(data, schema)

        # Write the parsed data to the specified JSON file
        with open(output_file_path, "w", encoding="utf-8") as f:
            json.dump(cleaned_data, f, indent=4, ensure_ascii=False)
        logger.info(
            f"Parsed and validated metadata successfully written to {output_file_path}"
        )
        return True
    except Exception as e:
        logger.error(f"An error occurred during validation or writing: {e}")
        raise
wrap_all_lines(pages)
Source code in src/tnh_scholar/journal_processing/journal_process.py
def wrap_all_lines(pages):
    return [wrap_lines(page) for page in pages]
wrap_lines(text)
Encloses each line of the input text with angle brackets.

Args:
    text (str): The input string containing lines separated by '\n'.

Returns:
    str: A string where each line is enclosed in angle brackets.

Example:
    >>> wrap_lines("This is a string with   \n   two lines.")
    '<This is a string with   >\n<   two lines.>'

Source code in src/tnh_scholar/journal_processing/journal_process.py
def wrap_lines(text: str) -> str:
    """
    Encloses each line of the input text with angle brackets.

    Args:
        text (str): The input string containing lines separated by '\n'.

    Returns:
        str: A string where each line is enclosed in angle brackets.

    Example:
        >>> wrap_lines("This is a string with   \n   two lines.")
        '<This is a string with   >\n<   two lines.>'
    """
    return "\n".join(f"<{line}>" for line in text.split("\n"))

logging_config

BASE_LOG_DIR = Path('./logs') module-attribute

BASE_LOG_NAME = 'tnh' module-attribute

DEFAULT_CONSOLE_FORMAT_STRING = '%(asctime)s - %(name)s - %(log_color)s%(levelname)s%(reset)s - %(message)s' module-attribute

DEFAULT_FILE_FORMAT_STRING = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' module-attribute

DEFAULT_LOG_FILEPATH = Path('main.log') module-attribute

LOG_COLORS = {'DEBUG': 'bold_green', 'INFO': 'cyan', 'PRIORITY_INFO': 'bold_cyan', 'WARNING': 'bold_yellow', 'ERROR': 'bold_red', 'CRITICAL': 'bold_red'} module-attribute

MAX_FILE_SIZE = 10 * 1024 * 1024 module-attribute

PRIORITY_INFO_LEVEL = 25 module-attribute

OMPFilter

Bases: Filter

Source code in src/tnh_scholar/logging_config.py
class OMPFilter(logging.Filter):
    def filter(self, record):
        # Suppress messages containing "OMP:"
        return "OMP:" not in record.getMessage()
filter(record)
Source code in src/tnh_scholar/logging_config.py
def filter(self, record):
    # Suppress messages containing "OMP:"
    return "OMP:" not in record.getMessage()

get_child_logger(name, console=None, separate_file=False)

Get a child logger that writes logs to a console or a specified file.

Parameters:

- name (str): The name of the child logger (e.g., module name). Required.
- console (bool, optional): If True, log to the console. If False, do not log to the console. If None, inherit console behavior from the parent logger. Default: None.
- separate_file (bool, optional): If True, also write this logger's output to its own rotating log file, named after the logger and placed under the existing root logs directory. Default: False.

Returns:

Type Description

logging.Logger: Configured child logger.

Source code in src/tnh_scholar/logging_config.py
def get_child_logger(name: str, console: bool = None, separate_file: bool = False):
    """
    Get a child logger that writes logs to a console or a specified file.

    Args:
        name (str): The name of the child logger (e.g., module name).
        console (bool, optional): If True, log to the console. If False, do not log to the console.
                                  If None, inherit console behavior from the parent logger.
        separate_file (bool, optional): If True, also write this logger's output to its own
                               rotating log file, named after the logger and placed under the
                               existing root logs directory.

    Returns:
        logging.Logger: Configured child logger.
    """
    # Create the fully qualified child logger name
    full_name = f"{BASE_LOG_NAME}.{name}"
    logger = logging.getLogger(full_name)
    logger.debug(f"Created logger with name: {logger.name}")

    # Check if the logger already has handlers to avoid duplication
    if not logger.handlers:
        # Add console handler if specified
        if console:
            console_handler = colorlog.StreamHandler()
            console_formatter = colorlog.ColoredFormatter(
                DEFAULT_CONSOLE_FORMAT_STRING,
                log_colors=LOG_COLORS,
            )
            console_handler.setFormatter(console_formatter)
            logger.addHandler(console_handler)

        # Add file handler if a file path is provided
        if separate_file:
            logfile = BASE_LOG_DIR / f"{name}.log"
            logfile.parent.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
            file_handler = RotatingFileHandler(
                filename=logfile,
                maxBytes=MAX_FILE_SIZE,  # Use the global MAX_FILE_SIZE
                backupCount=5,
            )
            file_formatter = logging.Formatter(DEFAULT_FILE_FORMAT_STRING)
            file_handler.setFormatter(file_formatter)
            logger.addHandler(file_handler)

        # Ensure the logger inherits handlers and settings from the base logger
        logger.propagate = True

    return logger

priority_info(self, message, *args, **kwargs)

Source code in src/tnh_scholar/logging_config.py
def priority_info(self, message, *args, **kwargs):
    if self.isEnabledFor(PRIORITY_INFO_LEVEL):
        self._log(PRIORITY_INFO_LEVEL, message, args, **kwargs)

setup_logging(log_level=logging.INFO, log_filepath=DEFAULT_LOG_FILEPATH, max_log_file_size=MAX_FILE_SIZE, backup_count=5, console_format=DEFAULT_CONSOLE_FORMAT_STRING, file_format=DEFAULT_FILE_FORMAT_STRING, console=True, suppressed_modules=None)

Configure the base logger with handlers, including the custom PRIORITY_INFO level.

Parameters:

- log_level (int): Logging level for the base logger. Default: logging.INFO.
- log_filepath (Path): Path to the log file, created under BASE_LOG_DIR. Default: DEFAULT_LOG_FILEPATH.
- max_log_file_size (int): Maximum log file size in bytes. Default: MAX_FILE_SIZE.
- backup_count (int): Number of backup log files to keep. Default: 5.
- console_format (str): Format string for console logs. Default: DEFAULT_CONSOLE_FORMAT_STRING.
- file_format (str): Format string for file logs. Default: DEFAULT_FILE_FORMAT_STRING.
- console (bool): If True, also log to the console with colorized output. Default: True.
- suppressed_modules (list): List of third-party modules to suppress logs for. Default: None.
Source code in src/tnh_scholar/logging_config.py
def setup_logging(
    log_level=logging.INFO,
    log_filepath=DEFAULT_LOG_FILEPATH,
    max_log_file_size=MAX_FILE_SIZE,  # 10MB
    backup_count=5,
    console_format=DEFAULT_CONSOLE_FORMAT_STRING,
    file_format=DEFAULT_FILE_FORMAT_STRING,
    console=True,
    suppressed_modules=None,
):
    """
    Configure the base logger with handlers, including the custom PRIORITY_INFO level.

    Args:
        log_level (int): Logging level for the base logger.
        log_filepath (Path): Path to the log file, created under BASE_LOG_DIR.
        max_log_file_size (int): Maximum log file size in bytes.
        backup_count (int): Number of backup log files to keep.
        console_format (str): Format string for console logs.
        file_format (str): Format string for file logs.
        console (bool): If True, also log to the console with colorized output.
        suppressed_modules (list): List of third-party modules to suppress logs for.
    """
    # Create the base logger
    log_file_path = BASE_LOG_DIR / log_filepath
    base_logger = logging.getLogger(BASE_LOG_NAME)
    base_logger.setLevel(log_level)

    # Clear existing handlers to avoid duplication
    base_logger.handlers.clear()

    # Colorized console handler
    if console:
        console_handler = colorlog.StreamHandler()
        console_formatter = colorlog.ColoredFormatter(
            console_format,
            log_colors=LOG_COLORS,
        )
        console_handler.setFormatter(console_formatter)
        # Add the OMP filter to the console handler
        console_handler.addFilter(OMPFilter())
        base_logger.addHandler(console_handler)

    # Plain file handler
    log_file_path.parent.mkdir(parents=True, exist_ok=True)
    file_handler = RotatingFileHandler(
        filename=log_file_path,
        maxBytes=max_log_file_size,
        backupCount=backup_count,
    )
    file_formatter = logging.Formatter(file_format)
    file_handler.setFormatter(file_formatter)
    base_logger.addHandler(file_handler)

    # Suppress noisy third-party logs
    if suppressed_modules:
        for module in suppressed_modules:
            logging.getLogger(module).setLevel(logging.WARNING)

    # Prevent propagation to the root logger
    base_logger.propagate = False

    return base_logger
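
Example (illustrative sketch; assumes the module is importable as tnh_scholar.logging_config, per the source path shown above):

import logging

from tnh_scholar.logging_config import setup_logging

# Configure the shared base logger once, near program startup.
base_logger = setup_logging(
    log_level=logging.DEBUG,
    suppressed_modules=["urllib3", "openai"],  # hypothetical module names to quiet
)
base_logger.info("Logging configured.")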

ocr_processing

DEFAULT_ANNOTATION_FONT_PATH = Path('/System/Library/Fonts/Supplemental/Arial.ttf') module-attribute

DEFAULT_ANNOTATION_FONT_SIZE = 12 module-attribute

DEFAULT_ANNOTATION_LANGUAGE_HINTS = ['vi'] module-attribute

DEFAULT_ANNOTATION_METHOD = 'DOCUMENT_TEXT_DETECTION' module-attribute

DEFAULT_ANNOTATION_OFFSET = 2 module-attribute

logger = logging.getLogger('ocr_processing') module-attribute

PDFParseWarning

Bases: Warning

Custom warning class for PDF parsing issues. Encapsulates minimal logic for displaying warnings with a custom format.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
class PDFParseWarning(Warning):
    """
    Custom warning class for PDF parsing issues.
    Encapsulates minimal logic for displaying warnings with a custom format.
    """

    @staticmethod
    def warn(message: str):
        """
        Display a warning message with custom formatting.

        Parameters:
            message (str): The warning message to display.
        """
        formatted_message = f"\033[93mPDFParseWarning: {message}\033[0m"
        print(formatted_message)  # Simply prints the warning
warn(message) staticmethod

Display a warning message with custom formatting.

Parameters:

Name Type Description Default
message str

The warning message to display.

required
Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
@staticmethod
def warn(message: str):
    """
    Display a warning message with custom formatting.

    Parameters:
        message (str): The warning message to display.
    """
    formatted_message = f"\033[93mPDFParseWarning: {message}\033[0m"
    print(formatted_message)  # Simply prints the warning

annotate_image_with_text(image, text_annotations, annotation_font_path, font_size=12)

Annotates a PIL image with bounding boxes and text descriptions from OCR results.

Parameters:

Name Type Description Default
image Image

The input PIL image to annotate.

required
text_annotations List[EntityAnnotation]

OCR results containing bounding boxes and text.

required
annotation_font_path str

Path to the font file for text annotations.

required
font_size int

Font size for text annotations.

12

Returns:

Type Description
Image

Image.Image: The annotated PIL image.

Raises:

Type Description
ValueError

If the input image is None.

IOError

If the font file cannot be loaded.

Exception

For any other unexpected errors.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def annotate_image_with_text(
    image: Image.Image,
    text_annotations: List[EntityAnnotation],
    annotation_font_path: str,
    font_size: int = 12,
) -> Image.Image:
    """
    Annotates a PIL image with bounding boxes and text descriptions from OCR results.

    Parameters:
        image (Image.Image): The input PIL image to annotate.
        text_annotations (List[EntityAnnotation]): OCR results containing bounding boxes and text.
        annotation_font_path (str): Path to the font file for text annotations.
        font_size (int): Font size for text annotations.

    Returns:
        Image.Image: The annotated PIL image.

    Raises:
        ValueError: If the input image is None.
        IOError: If the font file cannot be loaded.
        Exception: For any other unexpected errors.
    """
    if image is None:
        raise ValueError("The input image is None.")

    try:
        font = ImageFont.truetype(annotation_font_path, font_size)
    except IOError as e:
        raise IOError(f"Failed to load the font from '{annotation_font_path}': {e}")

    draw = ImageDraw.Draw(image)

    try:
        for i, text_obj in enumerate(text_annotations):
            vertices = [
                (vertex.x, vertex.y) for vertex in text_obj.bounding_poly.vertices
            ]
            if (
                len(vertices) == 4
            ):  # Ensure there are exactly 4 vertices for a rectangle
                # Draw the bounding box
                draw.polygon(vertices, outline="red", width=2)

                # Skip the first bounding box (whole text region)
                if i > 0:
                    # Offset the text position slightly for clarity
                    text_position = (vertices[0][0] + 2, vertices[0][1] + 2)
                    draw.text(
                        text_position, text_obj.description, fill="red", font=font
                    )

    except AttributeError as e:
        raise ValueError(f"Invalid text annotation structure: {e}")
    except Exception as e:
        raise Exception(f"An error occurred during image annotation: {e}")

    return image
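
Example (illustrative sketch; combines this function with process_single_image and start_image_annotator_client, both documented below; the import path follows the source location shown above):

from PIL import Image

from tnh_scholar.ocr_processing.ocr_processing import (
    annotate_image_with_text,
    process_single_image,
    start_image_annotator_client,
)

client = start_image_annotator_client()
image = Image.open("/path/to/page.png").convert("RGB")
response = process_single_image(image, client)

# Draw bounding boxes and recognized words onto a copy of the page image.
annotated = annotate_image_with_text(
    image.copy(),
    response.text_annotations,
    annotation_font_path="/System/Library/Fonts/Supplemental/Arial.ttf",
    font_size=12,
)
annotated.save("/path/to/page_annotated.png")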

build_processed_pdf(pdf_path, client, preprocessor=None, annotation_font_path=DEFAULT_ANNOTATION_FONT_PATH)

Processes a PDF document, extracting text, word locations, annotated images, and unannotated images.

Parameters:

Name Type Description Default
pdf_path Path

Path to the PDF file.

required
client ImageAnnotatorClient

Google Vision API client for text detection.

required
annotation_font_path Path

Path to the font file for annotations.

DEFAULT_ANNOTATION_FONT_PATH

Returns:

Type Description
Tuple[List[str], List[List[EntityAnnotation]], List[Image], List[Image]]

Tuple[List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]]: - List of extracted full-page texts (one entry per page). - List of word locations (list of vision.EntityAnnotation objects for each page). - List of annotated images (with bounding boxes and text annotations). - List of unannotated images (raw page images).

Raises:

Type Description
FileNotFoundError

If the specified PDF file does not exist.

ValueError

If the PDF file is invalid or contains no pages.

Exception

For any unexpected errors during processing.

Example

from pathlib import Path
from google.cloud import vision

pdf_path = Path("/path/to/example.pdf")
font_path = Path("/path/to/fonts/Arial.ttf")
client = vision.ImageAnnotatorClient()
try:
    text_pages, word_locations_list, annotated_images, unannotated_images = build_processed_pdf(
        pdf_path, client, annotation_font_path=font_path
    )
    print(f"Processed {len(text_pages)} pages successfully!")
except Exception as e:
    print(f"Error processing PDF: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def build_processed_pdf(
    pdf_path: Path,
    client: vision.ImageAnnotatorClient,
    preprocessor: Callable = None,
    annotation_font_path: Path = DEFAULT_ANNOTATION_FONT_PATH,
) -> Tuple[
    List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]
]:
    """
    Processes a PDF document, extracting text, word locations, annotated images, and unannotated images.

    Parameters:
        pdf_path (Path): Path to the PDF file.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        annotation_font_path (Path): Path to the font file for annotations.

    Returns:
        Tuple[List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]]:
            - List of extracted full-page texts (one entry per page).
            - List of word locations (list of `vision.EntityAnnotation` objects for each page).
            - List of annotated images (with bounding boxes and text annotations).
            - List of unannotated images (raw page images).

    Raises:
        FileNotFoundError: If the specified PDF file does not exist.
        ValueError: If the PDF file is invalid or contains no pages.
        Exception: For any unexpected errors during processing.

    Example:
        >>> from pathlib import Path
        >>> from google.cloud import vision
        >>> pdf_path = Path("/path/to/example.pdf")
        >>> font_path = Path("/path/to/fonts/Arial.ttf")
        >>> client = vision.ImageAnnotatorClient()
        >>> try:
        >>>     text_pages, word_locations_list, annotated_images, unannotated_images = build_processed_pdf(
        >>>         pdf_path, client, annotation_font_path=font_path
        >>>     )
        >>>     print(f"Processed {len(text_pages)} pages successfully!")
        >>> except Exception as e:
        >>>     print(f"Error processing PDF: {e}")
    """
    try:
        doc = load_pdf_pages(pdf_path)
    except FileNotFoundError as fnf_error:
        raise FileNotFoundError(f"Error loading PDF: {fnf_error}")
    except ValueError as ve:
        raise ValueError(f"Invalid PDF file: {ve}")
    except Exception as e:
        raise Exception(f"An unexpected error occurred while loading the PDF: {e}")

    if doc.page_count == 0:
        raise ValueError(f"The PDF file '{pdf_path}' contains no pages.")

    logger.info(f"Processing file with {doc.page_count} pages:\n\t{pdf_path}")

    text_pages = []
    word_locations_list = []
    annotated_images = []
    unannotated_images = []
    first_page_dimensions = None

    for page_num in range(doc.page_count):
        logger.info(f"Processing page {page_num + 1}/{doc.page_count}...")

        try:
            page = doc.load_page(page_num)
            (
                full_page_text,
                word_locations,
                annotated_image,
                unannotated_image,
                page_dimensions,
            ) = process_page(page, client, annotation_font_path, preprocessor)

            if full_page_text:  # this is not an empty page

                if page_num == 0:  # save first page info
                    first_page_dimensions = page_dimensions
                elif (
                    page_dimensions != first_page_dimensions
                ):  # verify page dimensions are consistent
                    PDFParseWarning.warn(
                        f"Page {page_num + 1} has different dimensions than page 1."
                        f"({page_dimensions}) compared to the first page: ({first_page_dimensions})."
                    )

                text_pages.append(full_page_text)
                word_locations_list.append(word_locations)
                annotated_images.append(annotated_image)
                unannotated_images.append(unannotated_image)
            else:
                PDFParseWarning.warn(
                    f"Page {page_num + 1} empty, added empty datastructures...\n"
                    # f"  (Note that total document length will be reduced.)"
                )

        except ValueError as ve:
            print(f"ValueError on page {page_num + 1}: {ve}")
        except OSError as oe:
            print(f"OSError on page {page_num + 1}: {oe}")
        except Exception as e:
            print(f"Unexpected error on page {page_num + 1}: {e}")

    print(f"page dimensions: {page_dimensions}")
    return text_pages, word_locations_list, annotated_images, unannotated_images

deserialize_entity_annotations_from_json(data)

Deserializes JSON data into a nested list of EntityAnnotation objects.

Parameters:

Name Type Description Default
data str

The JSON string containing serialized annotations.

required

Returns:

Type Description
List[List[EntityAnnotation]]

List[List[EntityAnnotation]]: The reconstructed nested list of EntityAnnotation objects.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def deserialize_entity_annotations_from_json(data: str) -> List[List[EntityAnnotation]]:
    """
    Deserializes JSON data into a nested list of EntityAnnotation objects.

    Parameters:
        data (str): The JSON string containing serialized annotations.

    Returns:
        List[List[EntityAnnotation]]: The reconstructed nested list of EntityAnnotation objects.
    """
    serialized_data = json.loads(data)
    deserialized_data = []

    for serialized_page in serialized_data:
        page_annotations = [
            EntityAnnotation.deserialize(base64.b64decode(serialized_annotation))
            for serialized_annotation in serialized_page
        ]
        deserialized_data.append(page_annotations)

    return deserialized_data

extract_image_from_page(page)

Extracts the first image from the given PDF page and returns it as a PIL Image.

Parameters:

Name Type Description Default
page Page

The PDF page object.

required

Returns:

Type Description
Image

Image.Image: The first image on the page as a Pillow Image object.

Raises:

Type Description
ValueError

If no images are found on the page or the image data is incomplete.

Exception

For unexpected errors during image extraction.

Example

import fitz
from PIL import Image

doc = fitz.open("/path/to/document.pdf")
page = doc.load_page(0)  # Load the first page
try:
    image = extract_image_from_page(page)
    image.show()  # Display the image
except Exception as e:
    print(f"Error extracting image: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def extract_image_from_page(page: fitz.Page) -> Image.Image:
    """
    Extracts the first image from the given PDF page and returns it as a PIL Image.

    Parameters:
        page (fitz.Page): The PDF page object.

    Returns:
        Image.Image: The first image on the page as a Pillow Image object.

    Raises:
        ValueError: If no images are found on the page or the image data is incomplete.
        Exception: For unexpected errors during image extraction.

    Example:
        >>> import fitz
        >>> from PIL import Image
        >>> doc = fitz.open("/path/to/document.pdf")
        >>> page = doc.load_page(0)  # Load the first page
        >>> try:
        >>>     image = extract_image_from_page(page)
        >>>     image.show()  # Display the image
        >>> except Exception as e:
        >>>     print(f"Error extracting image: {e}")
    """
    try:
        # Get images from the page
        images = page.get_images(full=True)
        if not images:
            raise ValueError("No images found on the page.")

        # Extract the first image reference
        xref = images[0][0]  # Get the first image's xref
        base_image = page.parent.extract_image(xref)

        # Validate the extracted image data
        if (
            "image" not in base_image
            or "width" not in base_image
            or "height" not in base_image
        ):
            raise ValueError("The extracted image data is incomplete.")

        # Convert the raw image bytes into a Pillow image
        image_bytes = base_image["image"]
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        return pil_image

    except ValueError as ve:
        raise ve  # Re-raise for calling functions to handle
    except Exception as e:
        raise Exception(f"An unexpected error occurred during image extraction: {e}")

get_page_dimensions(page)

Extracts the width and height of a single PDF page in both inches and pixels.

Parameters:

Name Type Description Default
page Page

A single PDF page object from PyMuPDF.

required

Returns:

Name Type Description
dict dict

A dictionary containing the width and height of the page in inches and pixels.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def get_page_dimensions(page: fitz.Page) -> dict:
    """
    Extracts the width and height of a single PDF page in both inches and pixels.

    Args:
        page (fitz.Page): A single PDF page object from PyMuPDF.

    Returns:
        dict: A dictionary containing the width and height of the page in inches and pixels.
    """
    # Get page dimensions in points and convert to inches
    page_width_pts, page_height_pts = page.rect.width, page.rect.height
    page_width_in = page_width_pts / 72  # Convert points to inches
    page_height_in = page_height_pts / 72

    # Extract the first image on the page (if any) to get pixel dimensions
    images = page.get_images(full=True)
    if images:
        xref = images[0][0]
        base_image = page.parent.extract_image(xref)
        width_px = base_image["width"]
        height_px = base_image["height"]
    else:
        width_px, height_px = None, None  # No image found on the page

    # Return dimensions
    return {
        "width_in": page_width_in,
        "height_in": page_height_in,
        "width_px": width_px,
        "height_px": height_px,
    }
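
Example (illustrative sketch; the import path follows the source location shown above):

import fitz

from tnh_scholar.ocr_processing.ocr_processing import get_page_dimensions

doc = fitz.open("/path/to/example.pdf")
dims = get_page_dimensions(doc.load_page(0))
# width_in/height_in come from the page rectangle; width_px/height_px come from
# the first embedded image, or are None if the page has no image.
print(dims["width_in"], dims["height_in"], dims["width_px"], dims["height_px"])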

load_pdf_pages(pdf_path)

Opens the PDF document and returns the fitz Document object.

Parameters:

Name Type Description Default
pdf_path Path

The path to the PDF file.

required

Returns:

Type Description
Document

fitz.Document: The loaded PDF document.

Raises:

Type Description
FileNotFoundError

If the specified file does not exist.

ValueError

If the file is not a valid PDF document.

Exception

For any unexpected error.

Example

from pathlib import Path

pdf_path = Path("/path/to/example.pdf")
try:
    pdf_doc = load_pdf_pages(pdf_path)
    print(f"PDF contains {pdf_doc.page_count} pages.")
except Exception as e:
    print(f"Error loading PDF: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def load_pdf_pages(pdf_path: Path) -> fitz.Document:
    """
    Opens the PDF document and returns the fitz Document object.

    Parameters:
        pdf_path (Path): The path to the PDF file.

    Returns:
        fitz.Document: The loaded PDF document.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        ValueError: If the file is not a valid PDF document.
        Exception: For any unexpected error.

    Example:
        >>> from pathlib import Path
        >>> pdf_path = Path("/path/to/example.pdf")
        >>> try:
        >>>     pdf_doc = load_pdf_pages(pdf_path)
        >>>     print(f"PDF contains {pdf_doc.page_count} pages.")
        >>> except Exception as e:
        >>>     print(f"Error loading PDF: {e}")
    """
    if not pdf_path.exists():
        raise FileNotFoundError(f"The file '{pdf_path}' does not exist.")

    if not pdf_path.suffix.lower() == ".pdf":
        raise ValueError(
            f"The file '{pdf_path}' is not a valid PDF document (expected '.pdf')."
        )

    try:
        return fitz.open(str(pdf_path))  # PyMuPDF expects a string path
    except Exception as e:
        raise Exception(f"An unexpected error occurred while opening the PDF: {e}")

load_processed_PDF_data(base_path)

Loads processed PDF data from files using metadata for file references.

Parameters:

Name Type Description Default
base_path Path

Directory containing the processed data (the ocr_data directory written by save_processed_pdf_data).

required

Returns:

Type Description
Tuple[List[str], List[List[EntityAnnotation]], List[Image], List[Image]]

Tuple[List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]]: - Loaded text pages. - Word locations (list of EntityAnnotation objects for each page). - Annotated images. - Unannotated images.

Raises:

Type Description
FileNotFoundError

If any required files are missing.

ValueError

If the metadata file is incomplete or invalid.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def load_processed_PDF_data(
    base_path: Path,
) -> Tuple[
    List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]
]:
    """
    Loads processed PDF data from files using metadata for file references.

    Parameters:
        base_path (Path): Directory containing the processed data (the ocr_data directory written by save_processed_pdf_data).

    Returns:
        Tuple[List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]]:
            - Loaded text pages.
            - Word locations (list of `EntityAnnotation` objects for each page).
            - Annotated images.
            - Unannotated images.

    Raises:
        FileNotFoundError: If any required files are missing.
        ValueError: If the metadata file is incomplete or invalid.
    """
    metadata_file = base_path / "metadata.json"

    # Load metadata
    try:
        with metadata_file.open("r", encoding="utf-8") as f:
            metadata = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"Metadata file '{metadata_file}' not found.")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid metadata file format: {e}")

    # Extract file paths from metadata
    text_pages_file = base_path / metadata.get("files", {}).get(
        "text_pages", "text_pages.json"
    )
    word_locations_file = base_path / metadata.get("files", {}).get(
        "word_locations", "word_locations.json"
    )
    images_dir = Path(metadata.get("images_directory", base_path / "images"))

    # Validate file paths
    if not text_pages_file.exists():
        raise FileNotFoundError(f"Text pages file '{text_pages_file}' not found.")
    if not word_locations_file.exists():
        raise FileNotFoundError(
            f"Word locations file '{word_locations_file}' not found."
        )
    if not images_dir.exists() or not images_dir.is_dir():
        raise FileNotFoundError(f"Images directory '{images_dir}' not found.")

    # Load text pages
    with text_pages_file.open("r", encoding="utf-8") as f:
        text_pages = json.load(f)

    # Load word locations
    with word_locations_file.open("r", encoding="utf-8") as f:
        serialized_word_locations = f.read()
        word_locations = deserialize_entity_annotations_from_json(
            serialized_word_locations
        )

    # Load images
    annotated_images = []
    unannotated_images = []
    for file in sorted(
        images_dir.iterdir()
    ):  # Iterate over files in the images directory
        if file.name.startswith("annotated_page_") and file.suffix == ".png":
            annotated_images.append(Image.open(file))
        elif file.name.startswith("unannotated_page_") and file.suffix == ".png":
            unannotated_images.append(Image.open(file))

    # Ensure images were loaded correctly
    if not annotated_images or not unannotated_images:
        raise ValueError(f"No images found in the directory '{images_dir}'.")

    return text_pages, word_locations, annotated_images, unannotated_images
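
Example (illustrative sketch; base_path points at the ocr_data directory written by save_processed_pdf_data, documented below):

from pathlib import Path

from tnh_scholar.ocr_processing.ocr_processing import load_processed_PDF_data

base_path = Path("/path/to/output") / "journal_name" / "ocr_data"
text_pages, word_locations, annotated_images, unannotated_images = load_processed_PDF_data(base_path)
print(f"Loaded {len(text_pages)} pages of OCR text.")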

make_image_preprocess_mask(mask_height)

Creates a preprocessing function that masks a specified height at the bottom of the image.

Parameters:

Name Type Description Default
mask_height float

The proportion of the image height to mask at the bottom (0.0 to 1.0).

required

Returns:

Type Description
Callable[[Image, int], Image]

Callable[[Image.Image, int], Image.Image]: A preprocessing function that takes an image

Callable[[Image, int], Image]

and page number as input and returns the processed image.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def make_image_preprocess_mask(
    mask_height: float,
) -> Callable[[Image.Image, int], Image.Image]:
    """
    Creates a preprocessing function that masks a specified height at the bottom of the image.

    Parameters:
        mask_height (float): The proportion of the image height to mask at the bottom (0.0 to 1.0).

    Returns:
        Callable[[Image.Image, int], Image.Image]: A preprocessing function that takes an image
        and page number as input and returns the processed image.
    """

    def pre_process_image(image: Image.Image, page_number: int) -> Image.Image:
        """
        Preprocesses the image by masking the bottom region or performing other preprocessing steps.

        Parameters:
            image (Image.Image): The input image as a Pillow object.
            page_number (int): The page number of the image (useful for conditional preprocessing).

        Returns:
            Image.Image: The preprocessed image.
        """

        if page_number > 0:  # don't apply mask to cover page.
            # Make a copy of the image to avoid modifying the original
            draw = ImageDraw.Draw(image)

            # Get image dimensions
            width, height = image.size

            # Mask the bottom region based on the specified height proportion
            mask_pixels = int(height * mask_height)
            draw.rectangle([(0, height - mask_pixels), (width, height)], fill="black")

        return image

    return pre_process_image
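
Example (illustrative sketch; masks the bottom 10% of every page after the cover and passes the resulting preprocessor to build_processed_pdf, documented above):

from pathlib import Path

from google.cloud import vision

from tnh_scholar.ocr_processing.ocr_processing import (
    build_processed_pdf,
    make_image_preprocess_mask,
)

client = vision.ImageAnnotatorClient()
mask_footer = make_image_preprocess_mask(0.1)  # hide running footers / page numbers

text_pages, word_locations, annotated, unannotated = build_processed_pdf(
    Path("/path/to/example.pdf"),
    client,
    preprocessor=mask_footer,
)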

pil_to_bytes(image, format='PNG')

Converts a Pillow image to raw bytes.

Parameters:

Name Type Description Default
image Image

The Pillow image object to convert.

required
format str

The format to save the image as (e.g., "PNG", "JPEG"). Default is "PNG".

'PNG'

Returns:

Name Type Description
bytes bytes

The raw bytes of the image.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def pil_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
    """
    Converts a Pillow image to raw bytes.

    Parameters:
        image (Image.Image): The Pillow image object to convert.
        format (str): The format to save the image as (e.g., "PNG", "JPEG"). Default is "PNG".

    Returns:
        bytes: The raw bytes of the image.
    """
    with io.BytesIO() as output:
        image.save(output, format=format)
        return output.getvalue()
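
Example (illustrative sketch):

from PIL import Image

from tnh_scholar.ocr_processing.ocr_processing import pil_to_bytes

image = Image.new("RGB", (100, 100), color="white")
png_bytes = pil_to_bytes(image)            # PNG by default
jpeg_bytes = pil_to_bytes(image, "JPEG")
print(len(png_bytes), len(jpeg_bytes))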

process_page(page, client, annotation_font_path, preprocessor=None)

Processes a single PDF page, extracting text, word locations, and annotated images.

Parameters:

Name Type Description Default
page Page

The PDF page object.

required
client ImageAnnotatorClient

Google Vision API client for text detection.

required
preprocessor Callable[[Image, int], Image]

Preprocessing function applied to the page image before OCR.

None
annotation_font_path str

Path to the font file for annotations.

required

Returns:

Type Description
Tuple[str, List[EntityAnnotation], Image, Image, dict]

Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]: - Full page text (str) - Word locations (List of vision.EntityAnnotation) - Annotated image (Pillow Image object) - Original unprocessed image (Pillow Image object) - Page dimensions (dict)

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def process_page(
    page: fitz.Page,
    client: vision.ImageAnnotatorClient,
    annotation_font_path: str,
    preprocessor: Callable[[Image.Image, int], Image.Image] = None,
) -> Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
    """
    Processes a single PDF page, extracting text, word locations, and annotated images.

    Parameters:
        page (fitz.Page): The PDF page object.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        preprocessor (Callable[[Image.Image, int], Image.Image], optional): Preprocessing function applied to the page image before OCR.
        annotation_font_path (str): Path to the font file for annotations.

    Returns:
        Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
            - Full page text (str)
            - Word locations (List of vision.EntityAnnotation)
            - Annotated image (Pillow Image object)
            - Original unprocessed image (Pillow Image object)
            - Page dimensions (dict)
    """
    # Extract the original image from the PDF page
    original_image = extract_image_from_page(page)

    # Make a copy of the original image for processing
    processed_image = original_image.copy()

    # Apply the preprocessing function (if provided)
    if preprocessor:
        # print("preprocessing...") # debug
        processed_image = preprocessor(processed_image, page.number)
        # processed_image.show() # debug

    # Annotate the processed image using the Vision API
    response = process_single_image(processed_image, client)

    if response:
        text_annotations = response.text_annotations
        # Extract full text and word locations
        full_page_text = text_annotations[0].description if text_annotations else ""
        word_locations = text_annotations[1:] if len(text_annotations) > 1 else []
    else:
        # return empty data
        full_page_text = ""
        word_locations = [EntityAnnotation()]
        text_annotations = [
            EntityAnnotation()
        ]  # create empty data structures to allow storing to proceed.

    # Create an annotated image with bounding boxes and labels
    annotated_image = annotate_image_with_text(
        processed_image, text_annotations, annotation_font_path
    )

    # Get page dimensions (from the original PDF page, not the image)
    page_dimensions = get_page_dimensions(page)

    return (
        full_page_text,
        word_locations,
        annotated_image,
        original_image,
        page_dimensions,
    )
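
Example (illustrative sketch; process_page is normally driven by build_processed_pdf, but it can be called on a single page):

import fitz
from google.cloud import vision

from tnh_scholar.ocr_processing.ocr_processing import process_page

client = vision.ImageAnnotatorClient()
doc = fitz.open("/path/to/example.pdf")
page = doc.load_page(0)

text, words, annotated_img, original_img, dims = process_page(
    page,
    client,
    annotation_font_path="/System/Library/Fonts/Supplemental/Arial.ttf",
)
print(dims)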

process_single_image(image, client, feature_type=DEFAULT_ANNOTATION_METHOD, language_hints=DEFAULT_ANNOTATION_LANGUAGE_HINTS)

Processes a single image with the Google Vision API and returns text annotations.

Parameters:

Name Type Description Default
image Image

The preprocessed Pillow image object.

required
client ImageAnnotatorClient

Google Vision API client for text detection.

required
feature_type str

Type of text detection to use ('TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION').

DEFAULT_ANNOTATION_METHOD
language_hints List

Language hints for OCR.

DEFAULT_ANNOTATION_LANGUAGE_HINTS

Returns:

Type Description
AnnotateImageResponse

vision.AnnotateImageResponse: The full Vision API response; text annotations are available via response.text_annotations.

Raises:

Type Description
ValueError

If an invalid feature type is specified.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def process_single_image(
    image: Image.Image,
    client: vision.ImageAnnotatorClient,
    feature_type: str = DEFAULT_ANNOTATION_METHOD,
    language_hints: List = DEFAULT_ANNOTATION_LANGUAGE_HINTS,
) -> vision.AnnotateImageResponse:
    """
    Processes a single image with the Google Vision API and returns text annotations.

    Parameters:
        image (Image.Image): The preprocessed Pillow image object.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        feature_type (str): Type of text detection to use ('TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION').
        language_hints (List): Language hints for OCR.

    Returns:
        vision.AnnotateImageResponse: The full Vision API response; text annotations are available via response.text_annotations.

    Raises:
        ValueError: If an invalid feature type is specified.
    """
    # Convert the Pillow image to bytes
    image_bytes = pil_to_bytes(image, format="PNG")

    # Map feature type
    feature_map = {
        "TEXT_DETECTION": vision.Feature.Type.TEXT_DETECTION,
        "DOCUMENT_TEXT_DETECTION": vision.Feature.Type.DOCUMENT_TEXT_DETECTION,
    }
    if feature_type not in feature_map:
        raise ValueError(
            f"Invalid feature type '{feature_type}'. Use 'TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION'."
        )

    # Prepare Vision API request
    vision_image = vision.Image(content=image_bytes)
    features = [vision.Feature(type=feature_map[feature_type])]
    image_context = vision.ImageContext(language_hints=language_hints)

    # Make the API call
    response = client.annotate_image(
        {"image": vision_image, "features": features, "image_context": image_context}
    )

    return response
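
Example (illustrative sketch; the function returns the full Vision API response, so the detected text is read from response.text_annotations):

from PIL import Image
from google.cloud import vision

from tnh_scholar.ocr_processing.ocr_processing import process_single_image

client = vision.ImageAnnotatorClient()
image = Image.open("/path/to/page.png").convert("RGB")

response = process_single_image(image, client, language_hints=["vi"])
if response.text_annotations:
    print(response.text_annotations[0].description)  # full-page text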

save_processed_pdf_data(output_dir, journal_name, text_pages, word_locations, annotated_images, unannotated_images)

Saves processed PDF data to files for later reloading.

Parameters:

Name Type Description Default
output_dir Path

Directory to save the data (as a Path object).

required
journal_name str

Base name for the output directory (usually the PDF name without extension).

required
text_pages List[str]

Extracted full-page text.

required
word_locations List[List[EntityAnnotation]]

Word locations and annotations from Vision API.

required
annotated_images List[Image]

Annotated images with bounding boxes.

required
unannotated_images List[Image]

Raw unannotated images.

required

Returns:

Type Description

None

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def save_processed_pdf_data(
    output_dir: Path,
    journal_name: str,
    text_pages: List[str],
    word_locations: List[List[EntityAnnotation]],
    annotated_images: List[Image.Image],
    unannotated_images: List[Image.Image],
):
    """
    Saves processed PDF data to files for later reloading.

    Parameters:
        output_dir (Path): Directory to save the data (as a Path object).
        journal_name (str): Base name for the output subdirectory (usually the PDF name without extension).
        text_pages (List[str]): Extracted full-page text.
        word_locations (List[List[EntityAnnotation]]): Word locations and annotations from Vision API.
        annotated_images (List[PIL.Image.Image]): Annotated images with bounding boxes.
        unannotated_images (List[PIL.Image.Image]): Raw unannotated images.

    Returns:
        None
    """
    # Create output directories
    base_path = output_dir / journal_name / "ocr_data"
    images_dir = base_path / "images"

    base_path.mkdir(parents=True, exist_ok=True)
    images_dir.mkdir(parents=True, exist_ok=True)

    # Save text data
    text_pages_file = base_path / "text_pages.json"
    with text_pages_file.open("w", encoding="utf-8") as f:
        json.dump(text_pages, f, indent=4, ensure_ascii=False)

    # Save word locations as JSON
    word_locations_file = base_path / "word_locations.json"
    serialized_word_locations = serialize_entity_annotations_to_json(word_locations)
    with word_locations_file.open("w", encoding="utf-8") as f:
        f.write(serialized_word_locations)

    # Save images
    for i, annotated_image in enumerate(annotated_images):
        annotated_file = images_dir / f"annotated_page_{i + 1}.png"
        annotated_image.save(annotated_file)
    for i, unannotated_image in enumerate(unannotated_images):
        unannotated_file = images_dir / f"unannotated_page_{i + 1}.png"
        unannotated_image.save(unannotated_file)

    # Save metadata
    metadata = {
        "source_pdf": journal_name,
        "page_count": len(text_pages),
        "images_directory": str(
            images_dir
        ),  # Convert Path to string for JSON serialization
        "files": {
            "text_pages": "text_pages.json",
            "word_locations": "word_locations.json",
        },
    }
    metadata_file = base_path / "metadata.json"
    with metadata_file.open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4)

    print(f"Processed data saved in: {base_path}")

serialize_entity_annotations_to_json(annotations)

Serializes a nested list of EntityAnnotation objects into a JSON-compatible format using Base64 encoding.

Parameters:

Name Type Description Default
annotations List[List[EntityAnnotation]]

The nested list of EntityAnnotation objects.

required

Returns:

Name Type Description
str str

The serialized data in JSON format as a string.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def serialize_entity_annotations_to_json(
    annotations: List[List[EntityAnnotation]],
) -> str:
    """
    Serializes a nested list of EntityAnnotation objects into a JSON-compatible format using Base64 encoding.

    Parameters:
        annotations (List[List[EntityAnnotation]]): The nested list of EntityAnnotation objects.

    Returns:
        str: The serialized data in JSON format as a string.
    """
    serialized_data = []
    for page_annotations in annotations:
        serialized_page = [
            base64.b64encode(annotation.SerializeToString()).decode("utf-8")
            for annotation in page_annotations
        ]
        serialized_data.append(serialized_page)

    # Convert to a JSON string
    return json.dumps(serialized_data, indent=4)
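
Example (illustrative round trip with deserialize_entity_annotations_from_json, documented above; word_locations is the nested annotation list returned by build_processed_pdf):

from pathlib import Path

from google.cloud import vision

from tnh_scholar.ocr_processing.ocr_processing import (
    build_processed_pdf,
    deserialize_entity_annotations_from_json,
    serialize_entity_annotations_to_json,
)

client = vision.ImageAnnotatorClient()
_, word_locations, _, _ = build_processed_pdf(Path("/path/to/example.pdf"), client)

json_str = serialize_entity_annotations_to_json(word_locations)  # Base64-encoded protobufs in JSON
restored = deserialize_entity_annotations_from_json(json_str)
assert len(restored) == len(word_locations)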

start_image_annotator_client(credentials_file=None, api_endpoint='vision.googleapis.com', timeout=(10, 30), enable_logging=False)

Starts and returns a Google Vision API ImageAnnotatorClient with optional configuration.

Parameters:

Name Type Description Default
credentials_file str

Path to the credentials JSON file. If None, uses the default environment variable.

None
api_endpoint str

Custom API endpoint for the Vision API. Default is the global endpoint.

'vision.googleapis.com'
timeout Tuple[int, int]

Connection and read timeouts in seconds. Default is (10, 30).

(10, 30)
enable_logging bool

Enable detailed logging for debugging. Default is False.

False

Returns:

Type Description
ImageAnnotatorClient

vision.ImageAnnotatorClient: Configured Vision API client.

Raises:

Type Description
FileNotFoundError

If the specified credentials file is not found.

Exception

For unexpected errors during client setup.

Example

client = start_image_annotator_client(
    credentials_file="/path/to/credentials.json",
    api_endpoint="vision.googleapis.com",
    timeout=(10, 30),
    enable_logging=True
)
print("Google Vision API client initialized.")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def start_image_annotator_client(
    credentials_file: str = None,
    api_endpoint: str = "vision.googleapis.com",
    timeout: Tuple[int, int] = (10, 30),
    enable_logging: bool = False,
) -> vision.ImageAnnotatorClient:
    """
    Starts and returns a Google Vision API ImageAnnotatorClient with optional configuration.

    Parameters:
        credentials_file (str): Path to the credentials JSON file. If None, uses the default environment variable.
        api_endpoint (str): Custom API endpoint for the Vision API. Default is the global endpoint.
        timeout (Tuple[int, int]): Connection and read timeouts in seconds. Default is (10, 30).
        enable_logging (bool): Enable detailed logging for debugging. Default is False.

    Returns:
        vision.ImageAnnotatorClient: Configured Vision API client.

    Raises:
        FileNotFoundError: If the specified credentials file is not found.
        Exception: For unexpected errors during client setup.

    Example:
        >>> client = start_image_annotator_client(
        >>>     credentials_file="/path/to/credentials.json",
        >>>     api_endpoint="vision.googleapis.com",
        >>>     timeout=(10, 30),
        >>>     enable_logging=True
        >>> )
        >>> print("Google Vision API client initialized.")
    """
    try:
        # Set up credentials
        if credentials_file:
            if not os.path.exists(credentials_file):
                raise FileNotFoundError(
                    f"Credentials file '{credentials_file}' not found."
                )
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_file

        # Configure client options
        client_options = {"api_endpoint": api_endpoint}
        client = vision.ImageAnnotatorClient(client_options=client_options)

        # Optionally enable logging
        if enable_logging:
            print(f"Vision API Client started with endpoint: {api_endpoint}")
            print(f"Timeout settings: Connect={timeout[0]}s, Read={timeout[1]}s")

        return client

    except Exception as e:
        raise Exception(f"Failed to initialize ImageAnnotatorClient: {e}")

ocr_editor

current_image = st.session_state.current_image module-attribute
current_page_index = st.session_state.current_page_index module-attribute
current_text = pages[current_page_index] module-attribute
edited_text = st.text_area('Edit OCR Text', value=st.session_state.current_text, key=f'text_area_{st.session_state.current_page_index}', height=400) module-attribute
image_directory = st.sidebar.text_input('Image Directory', value='./images') module-attribute
ocr_text_directory = st.sidebar.text_input('OCR Text Directory', value='./ocr_text') module-attribute
pages = st.session_state.pages module-attribute
save_path = os.path.join(ocr_text_directory, 'updated_ocr.xml') module-attribute
tree = st.session_state.tree module-attribute
uploaded_image_file = st.sidebar.file_uploader('Upload an Image', type=['jpg', 'jpeg', 'png', 'pdf']) module-attribute
uploaded_text_file = st.sidebar.file_uploader('Upload OCR Text File', type=['xml']) module-attribute
extract_pages(tree)

Extract page data from the XML tree.

Parameters:

Name Type Description Default
tree ElementTree

Parsed XML tree.

required

Returns:

Name Type Description
list

A list of dictionaries containing 'number' and 'text' for each page.

Source code in src/tnh_scholar/ocr_processing/ocr_editor.py
def extract_pages(tree):
    """
    Extract page data from the XML tree.

    Args:
        tree (etree.ElementTree): Parsed XML tree.

    Returns:
        list: A list of dictionaries containing 'number' and 'text' for each page.
    """
    pages = []
    for page in tree.xpath("//page"):
        page_number = page.get("page")
        ocr_text = page.text.strip() if page.text else ""
        pages.append({"number": page_number, "text": ocr_text})
    return pages
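
Example (illustrative sketch; ocr_editor is a Streamlit app, so importing it starts UI code. The snippet below builds the same page list that extract_pages produces, using the XML shape the function expects):

from io import BytesIO

from lxml import etree

xml_bytes = b"""<document>
  <page page="1">First page OCR text</page>
  <page page="2">Second page OCR text</page>
</document>"""

tree = etree.parse(BytesIO(xml_bytes))
# Equivalent to extract_pages(tree): one dict per <page> element.
pages = [
    {"number": page.get("page"), "text": (page.text or "").strip()}
    for page in tree.xpath("//page")
]
print(pages[0])  # {'number': '1', 'text': 'First page OCR text'}
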
load_xml(file_obj)

Load an XML file from a file-like object.

Source code in src/tnh_scholar/ocr_processing/ocr_editor.py
def load_xml(file_obj):
    """
    Load an XML file from a file-like object.
    """
    try:
        tree = etree.parse(file_obj)  # Directly parse the file-like object
        return tree
    except etree.XMLSyntaxError as e:
        st.error(f"Error parsing XML file: {e}")
        return None
save_xml(tree, file_path)

Save the modified XML tree to a file.

Source code in src/tnh_scholar/ocr_processing/ocr_editor.py
def save_xml(tree, file_path):
    """
    Save the modified XML tree to a file.
    """
    with open(file_path, "wb") as file:
        tree.write(file, pretty_print=True, encoding="utf-8", xml_declaration=True)

ocr_processing

DEFAULT_ANNOTATION_FONT_PATH = Path('/System/Library/Fonts/Supplemental/Arial.ttf') module-attribute
DEFAULT_ANNOTATION_FONT_SIZE = 12 module-attribute
DEFAULT_ANNOTATION_LANGUAGE_HINTS = ['vi'] module-attribute
DEFAULT_ANNOTATION_METHOD = 'DOCUMENT_TEXT_DETECTION' module-attribute
DEFAULT_ANNOTATION_OFFSET = 2 module-attribute
logger = logging.getLogger('ocr_processing') module-attribute
PDFParseWarning

Bases: Warning

Custom warning class for PDF parsing issues. Encapsulates minimal logic for displaying warnings with a custom format.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
class PDFParseWarning(Warning):
    """
    Custom warning class for PDF parsing issues.
    Encapsulates minimal logic for displaying warnings with a custom format.
    """

    @staticmethod
    def warn(message: str):
        """
        Display a warning message with custom formatting.

        Parameters:
            message (str): The warning message to display.
        """
        formatted_message = f"\033[93mPDFParseWarning: {message}\033[0m"
        print(formatted_message)  # Simply prints the warning
warn(message) staticmethod

Display a warning message with custom formatting.

Parameters:

Name Type Description Default
message str

The warning message to display.

required
Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
@staticmethod
def warn(message: str):
    """
    Display a warning message with custom formatting.

    Parameters:
        message (str): The warning message to display.
    """
    formatted_message = f"\033[93mPDFParseWarning: {message}\033[0m"
    print(formatted_message)  # Simply prints the warning
annotate_image_with_text(image, text_annotations, annotation_font_path, font_size=12)

Annotates a PIL image with bounding boxes and text descriptions from OCR results.

Parameters:

Name Type Description Default
image Image

The input PIL image to annotate.

required
text_annotations List[EntityAnnotation]

OCR results containing bounding boxes and text.

required
annotation_font_path str

Path to the font file for text annotations.

required
font_size int

Font size for text annotations.

12

Returns:

Type Description
Image

Image.Image: The annotated PIL image.

Raises:

Type Description
ValueError

If the input image is None.

IOError

If the font file cannot be loaded.

Exception

For any other unexpected errors.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def annotate_image_with_text(
    image: Image.Image,
    text_annotations: List[EntityAnnotation],
    annotation_font_path: str,
    font_size: int = 12,
) -> Image.Image:
    """
    Annotates a PIL image with bounding boxes and text descriptions from OCR results.

    Parameters:
        image (Image.Image): The input PIL image to annotate.
        text_annotations (List[EntityAnnotation]): OCR results containing bounding boxes and text.
        annotation_font_path (str): Path to the font file for text annotations.
        font_size (int): Font size for text annotations.

    Returns:
        Image.Image: The annotated PIL image.

    Raises:
        ValueError: If the input image is None.
        IOError: If the font file cannot be loaded.
        Exception: For any other unexpected errors.
    """
    if image is None:
        raise ValueError("The input image is None.")

    try:
        font = ImageFont.truetype(annotation_font_path, font_size)
    except IOError as e:
        raise IOError(f"Failed to load the font from '{annotation_font_path}': {e}")

    draw = ImageDraw.Draw(image)

    try:
        for i, text_obj in enumerate(text_annotations):
            vertices = [
                (vertex.x, vertex.y) for vertex in text_obj.bounding_poly.vertices
            ]
            if (
                len(vertices) == 4
            ):  # Ensure there are exactly 4 vertices for a rectangle
                # Draw the bounding box
                draw.polygon(vertices, outline="red", width=2)

                # Skip the first bounding box (whole text region)
                if i > 0:
                    # Offset the text position slightly for clarity
                    text_position = (vertices[0][0] + 2, vertices[0][1] + 2)
                    draw.text(
                        text_position, text_obj.description, fill="red", font=font
                    )

    except AttributeError as e:
        raise ValueError(f"Invalid text annotation structure: {e}")
    except Exception as e:
        raise Exception(f"An error occurred during image annotation: {e}")

    return image
build_processed_pdf(pdf_path, client, preprocessor=None, annotation_font_path=DEFAULT_ANNOTATION_FONT_PATH)

Processes a PDF document, extracting text, word locations, annotated images, and unannotated images.

Parameters:

Name Type Description Default
pdf_path Path

Path to the PDF file.

required
client ImageAnnotatorClient

Google Vision API client for text detection.

required
annotation_font_path Path

Path to the font file for annotations.

DEFAULT_ANNOTATION_FONT_PATH

Returns:

Type Description
Tuple[List[str], List[List[EntityAnnotation]], List[Image], List[Image]]

Tuple[List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]]: - List of extracted full-page texts (one entry per page). - List of word locations (list of vision.EntityAnnotation objects for each page). - List of annotated images (with bounding boxes and text annotations). - List of unannotated images (raw page images).

Raises:

Type Description
FileNotFoundError

If the specified PDF file does not exist.

ValueError

If the PDF file is invalid or contains no pages.

Exception

For any unexpected errors during processing.

Example

from pathlib import Path
from google.cloud import vision

pdf_path = Path("/path/to/example.pdf")
font_path = Path("/path/to/fonts/Arial.ttf")
client = vision.ImageAnnotatorClient()
try:
    text_pages, word_locations_list, annotated_images, unannotated_images = build_processed_pdf(
        pdf_path, client, annotation_font_path=font_path
    )
    print(f"Processed {len(text_pages)} pages successfully!")
except Exception as e:
    print(f"Error processing PDF: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def build_processed_pdf(
    pdf_path: Path,
    client: vision.ImageAnnotatorClient,
    preprocessor: Callable = None,
    annotation_font_path: Path = DEFAULT_ANNOTATION_FONT_PATH,
) -> Tuple[
    List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]
]:
    """
    Processes a PDF document, extracting text, word locations, annotated images, and unannotated images.

    Parameters:
        pdf_path (Path): Path to the PDF file.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        annotation_font_path (Path): Path to the font file for annotations.

    Returns:
        Tuple[List[str], List[List[vision.EntityAnnotation]], List[Image.Image], List[Image.Image]]:
            - List of extracted full-page texts (one entry per page).
            - List of word locations (list of `vision.EntityAnnotation` objects for each page).
            - List of annotated images (with bounding boxes and text annotations).
            - List of unannotated images (raw page images).

    Raises:
        FileNotFoundError: If the specified PDF file does not exist.
        ValueError: If the PDF file is invalid or contains no pages.
        Exception: For any unexpected errors during processing.

    Example:
        >>> from pathlib import Path
        >>> from google.cloud import vision
        >>> pdf_path = Path("/path/to/example.pdf")
        >>> font_path = Path("/path/to/fonts/Arial.ttf")
        >>> client = vision.ImageAnnotatorClient()
        >>> try:
        >>>     text_pages, word_locations_list, annotated_images, unannotated_images = build_processed_pdf(
        >>>         pdf_path, client, annotation_font_path=font_path
        >>>     )
        >>>     print(f"Processed {len(text_pages)} pages successfully!")
        >>> except Exception as e:
        >>>     print(f"Error processing PDF: {e}")
    """
    try:
        doc = load_pdf_pages(pdf_path)
    except FileNotFoundError as fnf_error:
        raise FileNotFoundError(f"Error loading PDF: {fnf_error}")
    except ValueError as ve:
        raise ValueError(f"Invalid PDF file: {ve}")
    except Exception as e:
        raise Exception(f"An unexpected error occurred while loading the PDF: {e}")

    if doc.page_count == 0:
        raise ValueError(f"The PDF file '{pdf_path}' contains no pages.")

    logger.info(f"Processing file with {doc.page_count} pages:\n\t{pdf_path}")

    text_pages = []
    word_locations_list = []
    annotated_images = []
    unannotated_images = []
    first_page_dimensions = None

    for page_num in range(doc.page_count):
        logger.info(f"Processing page {page_num + 1}/{doc.page_count}...")

        try:
            page = doc.load_page(page_num)
            (
                full_page_text,
                word_locations,
                annotated_image,
                unannotated_image,
                page_dimensions,
            ) = process_page(page, client, annotation_font_path, preprocessor)

            if full_page_text:  # this is not an empty page

                if page_num == 0:  # save first page info
                    first_page_dimensions = page_dimensions
                elif (
                    page_dimensions != first_page_dimensions
                ):  # verify page dimensions are consistent
                    PDFParseWarning.warn(
                        f"Page {page_num + 1} has different dimensions than page 1."
                        f"({page_dimensions}) compared to the first page: ({first_page_dimensions})."
                    )

                text_pages.append(full_page_text)
                word_locations_list.append(word_locations)
                annotated_images.append(annotated_image)
                unannotated_images.append(unannotated_image)
            else:
                PDFParseWarning.warn(
                    f"Page {page_num + 1} empty, added empty datastructures...\n"
                    # f"  (Note that total document length will be reduced.)"
                )

        except ValueError as ve:
            logger.error(f"ValueError on page {page_num + 1}: {ve}")
        except OSError as oe:
            logger.error(f"OSError on page {page_num + 1}: {oe}")
        except Exception as e:
            logger.error(f"Unexpected error on page {page_num + 1}: {e}")

    print(f"page dimensions: {page_dimensions}")
    return text_pages, word_locations_list, annotated_images, unannotated_images
deserialize_entity_annotations_from_json(data)

Deserializes JSON data into a nested list of EntityAnnotation objects.

Parameters:
    data (str): The JSON string containing serialized annotations. (required)

Returns:
    List[List[EntityAnnotation]]: The reconstructed nested list of EntityAnnotation objects.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def deserialize_entity_annotations_from_json(data: str) -> List[List[EntityAnnotation]]:
    """
    Deserializes JSON data into a nested list of EntityAnnotation objects.

    Parameters:
        data (str): The JSON string containing serialized annotations.

    Returns:
        List[List[EntityAnnotation]]: The reconstructed nested list of EntityAnnotation objects.
    """
    serialized_data = json.loads(data)
    deserialized_data = []

    for serialized_page in serialized_data:
        page_annotations = [
            EntityAnnotation.deserialize(base64.b64decode(serialized_annotation))
            for serialized_annotation in serialized_page
        ]
        deserialized_data.append(page_annotations)

    return deserialized_data
extract_image_from_page(page)

Extracts the first image from the given PDF page and returns it as a PIL Image.

Parameters:
    page (fitz.Page): The PDF page object. (required)

Returns:
    Image.Image: The first image on the page as a Pillow Image object.

Raises:
    ValueError: If no images are found on the page or the image data is incomplete.
    Exception: For unexpected errors during image extraction.

Example:
    import fitz
    from PIL import Image

    doc = fitz.open("/path/to/document.pdf")
    page = doc.load_page(0)  # Load the first page
    try:
        image = extract_image_from_page(page)
        image.show()  # Display the image
    except Exception as e:
        print(f"Error extracting image: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def extract_image_from_page(page: fitz.Page) -> Image.Image:
    """
    Extracts the first image from the given PDF page and returns it as a PIL Image.

    Parameters:
        page (fitz.Page): The PDF page object.

    Returns:
        Image.Image: The first image on the page as a Pillow Image object.

    Raises:
        ValueError: If no images are found on the page or the image data is incomplete.
        Exception: For unexpected errors during image extraction.

    Example:
        >>> import fitz
        >>> from PIL import Image
        >>> doc = fitz.open("/path/to/document.pdf")
        >>> page = doc.load_page(0)  # Load the first page
        >>> try:
        >>>     image = extract_image_from_page(page)
        >>>     image.show()  # Display the image
        >>> except Exception as e:
        >>>     print(f"Error extracting image: {e}")
    """
    try:
        # Get images from the page
        images = page.get_images(full=True)
        if not images:
            raise ValueError("No images found on the page.")

        # Extract the first image reference
        xref = images[0][0]  # Get the first image's xref
        base_image = page.parent.extract_image(xref)

        # Validate the extracted image data
        if (
            "image" not in base_image
            or "width" not in base_image
            or "height" not in base_image
        ):
            raise ValueError("The extracted image data is incomplete.")

        # Convert the raw image bytes into a Pillow image
        image_bytes = base_image["image"]
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        return pil_image

    except ValueError as ve:
        raise ve  # Re-raise for calling functions to handle
    except Exception as e:
        raise Exception(f"An unexpected error occurred during image extraction: {e}")
get_page_dimensions(page)

Extracts the width and height of a single PDF page in both inches and pixels.

Parameters:
    page (fitz.Page): A single PDF page object from PyMuPDF. (required)

Returns:
    dict: A dictionary containing the width and height of the page in inches and pixels.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def get_page_dimensions(page: fitz.Page) -> dict:
    """
    Extracts the width and height of a single PDF page in both inches and pixels.

    Args:
        page (fitz.Page): A single PDF page object from PyMuPDF.

    Returns:
        dict: A dictionary containing the width and height of the page in inches and pixels.
    """
    # Get page dimensions in points and convert to inches
    page_width_pts, page_height_pts = page.rect.width, page.rect.height
    page_width_in = page_width_pts / 72  # Convert points to inches
    page_height_in = page_height_pts / 72

    # Extract the first image on the page (if any) to get pixel dimensions
    images = page.get_images(full=True)
    if images:
        xref = images[0][0]
        base_image = page.parent.extract_image(xref)
        width_px = base_image["width"]
        height_px = base_image["height"]
    else:
        width_px, height_px = None, None  # No image found on the page

    # Return dimensions
    return {
        "width_in": page_width_in,
        "height_in": page_height_in,
        "width_px": width_px,
        "height_px": height_px,
    }
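A brief usage sketch (the PDF path is a placeholder and get_page_dimensions is assumed to be imported from tnh_scholar.ocr_processing.ocr_processing):

import fitz  # PyMuPDF

doc = fitz.open("/path/to/example.pdf")  # placeholder path
dims = get_page_dimensions(doc.load_page(0))
print(f"{dims['width_in']:.2f} x {dims['height_in']:.2f} inches")
print(f"Embedded image: {dims['width_px']} x {dims['height_px']} px")  # None values if the page has no image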
load_pdf_pages(pdf_path)

Opens the PDF document and returns the fitz Document object.

Parameters:
    pdf_path (Path): The path to the PDF file. (required)

Returns:
    fitz.Document: The loaded PDF document.

Raises:
    FileNotFoundError: If the specified file does not exist.
    ValueError: If the file is not a valid PDF document.
    Exception: For any unexpected error.

Example:
    from pathlib import Path

    pdf_path = Path("/path/to/example.pdf")
    try:
        pdf_doc = load_pdf_pages(pdf_path)
        print(f"PDF contains {pdf_doc.page_count} pages.")
    except Exception as e:
        print(f"Error loading PDF: {e}")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def load_pdf_pages(pdf_path: Path) -> fitz.Document:
    """
    Opens the PDF document and returns the fitz Document object.

    Parameters:
        pdf_path (Path): The path to the PDF file.

    Returns:
        fitz.Document: The loaded PDF document.

    Raises:
        FileNotFoundError: If the specified file does not exist.
        ValueError: If the file is not a valid PDF document.
        Exception: For any unexpected error.

    Example:
        >>> from pathlib import Path
        >>> pdf_path = Path("/path/to/example.pdf")
        >>> try:
        >>>     pdf_doc = load_pdf_pages(pdf_path)
        >>>     print(f"PDF contains {pdf_doc.page_count} pages.")
        >>> except Exception as e:
        >>>     print(f"Error loading PDF: {e}")
    """
    if not pdf_path.exists():
        raise FileNotFoundError(f"The file '{pdf_path}' does not exist.")

    if not pdf_path.suffix.lower() == ".pdf":
        raise ValueError(
            f"The file '{pdf_path}' is not a valid PDF document (expected '.pdf')."
        )

    try:
        return fitz.open(str(pdf_path))  # PyMuPDF expects a string path
    except Exception as e:
        raise Exception(f"An unexpected error occurred while opening the PDF: {e}")
load_processed_PDF_data(base_path)

Loads processed PDF data from files using metadata for file references.

Parameters:
    base_path (Path): Directory containing the processed data, typically the "ocr_data" directory created by save_processed_pdf_data. (required)

Returns:
    Tuple[List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]]:
        - Loaded text pages.
        - Word locations (list of EntityAnnotation objects for each page).
        - Annotated images.
        - Unannotated images.

Raises:
    FileNotFoundError: If any required files are missing.
    ValueError: If the metadata file is incomplete or invalid.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def load_processed_PDF_data(
    base_path: Path,
) -> Tuple[
    List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]
]:
    """
    Loads processed PDF data from files using metadata for file references.

    Parameters:
        base_path (Path): Directory containing the processed data (as a Path object),
            typically the "ocr_data" directory created by `save_processed_pdf_data`.

    Returns:
        Tuple[List[str], List[List[EntityAnnotation]], List[Image.Image], List[Image.Image]]:
            - Loaded text pages.
            - Word locations (list of `EntityAnnotation` objects for each page).
            - Annotated images.
            - Unannotated images.

    Raises:
        FileNotFoundError: If any required files are missing.
        ValueError: If the metadata file is incomplete or invalid.
    """
    metadata_file = base_path / "metadata.json"

    # Load metadata
    try:
        with metadata_file.open("r", encoding="utf-8") as f:
            metadata = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"Metadata file '{metadata_file}' not found.")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid metadata file format: {e}")

    # Extract file paths from metadata
    text_pages_file = base_path / metadata.get("files", {}).get(
        "text_pages", "text_pages.json"
    )
    word_locations_file = base_path / metadata.get("files", {}).get(
        "word_locations", "word_locations.json"
    )
    images_dir = Path(metadata.get("images_directory", base_path / "images"))

    # Validate file paths
    if not text_pages_file.exists():
        raise FileNotFoundError(f"Text pages file '{text_pages_file}' not found.")
    if not word_locations_file.exists():
        raise FileNotFoundError(
            f"Word locations file '{word_locations_file}' not found."
        )
    if not images_dir.exists() or not images_dir.is_dir():
        raise FileNotFoundError(f"Images directory '{images_dir}' not found.")

    # Load text pages
    with text_pages_file.open("r", encoding="utf-8") as f:
        text_pages = json.load(f)

    # Load word locations
    with word_locations_file.open("r", encoding="utf-8") as f:
        serialized_word_locations = f.read()
        word_locations = deserialize_entity_annotations_from_json(
            serialized_word_locations
        )

    # Load images
    annotated_images = []
    unannotated_images = []
    for file in sorted(
        images_dir.iterdir()
    ):  # Iterate over files in the images directory
        if file.name.startswith("annotated_page_") and file.suffix == ".png":
            annotated_images.append(Image.open(file))
        elif file.name.startswith("unannotated_page_") and file.suffix == ".png":
            unannotated_images.append(Image.open(file))

    # Ensure images were loaded correctly
    if not annotated_images or not unannotated_images:
        raise ValueError(f"No images found in the directory '{images_dir}'.")

    return text_pages, word_locations, annotated_images, unannotated_images
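A minimal reload sketch; the directory below is a placeholder for an ocr_data directory previously written by save_processed_pdf_data.

from pathlib import Path

base_path = Path("output/journal_1992_q1/ocr_data")  # placeholder, produced earlier by save_processed_pdf_data
text_pages, word_locations, annotated_images, unannotated_images = load_processed_PDF_data(base_path)
print(f"Reloaded {len(text_pages)} pages of OCR text.")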
make_image_preprocess_mask(mask_height)

Creates a preprocessing function that masks a specified height at the bottom of the image.

Parameters:
    mask_height (float): The proportion of the image height to mask at the bottom (0.0 to 1.0). (required)

Returns:
    Callable[[Image.Image, int], Image.Image]: A preprocessing function that takes an image and page number as input and returns the processed image.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def make_image_preprocess_mask(
    mask_height: float,
) -> Callable[[Image.Image, int], Image.Image]:
    """
    Creates a preprocessing function that masks a specified height at the bottom of the image.

    Parameters:
        mask_height (float): The proportion of the image height to mask at the bottom (0.0 to 1.0).

    Returns:
        Callable[[Image.Image, int], Image.Image]: A preprocessing function that takes an image
        and page number as input and returns the processed image.
    """

    def pre_process_image(image: Image.Image, page_number: int) -> Image.Image:
        """
        Preprocesses the image by masking the bottom region or performing other preprocessing steps.

        Parameters:
            image (Image.Image): The input image as a Pillow object.
            page_number (int): The page number of the image (useful for conditional preprocessing).

        Returns:
            Image.Image: The preprocessed image.
        """

        if page_number > 0:  # don't apply mask to cover page.
            # Make a copy of the image to avoid modifying the original
            draw = ImageDraw.Draw(image)

            # Get image dimensions
            width, height = image.size

            # Mask the bottom region based on the specified height proportion
            mask_pixels = int(height * mask_height)
            draw.rectangle([(0, height - mask_pixels), (width, height)], fill="black")

        return image

    return pre_process_image
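A usage sketch that masks the bottom 10% of every page (except the cover) before OCR. The paths are placeholders, and passing the mask to build_processed_pdf via a preprocessor keyword is an assumption based on the parameter name used in that function's body.

from pathlib import Path

mask_footer = make_image_preprocess_mask(0.1)  # hide page footers from OCR
client = start_image_annotator_client()
text_pages, word_locations, annotated, unannotated = build_processed_pdf(
    Path("/path/to/journal.pdf"),       # placeholder PDF path
    client,
    Path("/path/to/fonts/Arial.ttf"),   # placeholder annotation font
    preprocessor=mask_footer,           # keyword name assumed from the function body above
)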
pil_to_bytes(image, format='PNG')

Converts a Pillow image to raw bytes.

Parameters:
    image (Image.Image): The Pillow image object to convert. (required)
    format (str): The format to save the image as (e.g., "PNG", "JPEG"). Default is "PNG".

Returns:
    bytes: The raw bytes of the image.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def pil_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
    """
    Converts a Pillow image to raw bytes.

    Parameters:
        image (Image.Image): The Pillow image object to convert.
        format (str): The format to save the image as (e.g., "PNG", "JPEG"). Default is "PNG".

    Returns:
        bytes: The raw bytes of the image.
    """
    with io.BytesIO() as output:
        image.save(output, format=format)
        return output.getvalue()
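A tiny sketch converting an in-memory Pillow image to PNG bytes:

from PIL import Image

img = Image.new("RGB", (200, 100), color="white")  # toy image
png_bytes = pil_to_bytes(img, format="PNG")
print(f"Encoded {len(png_bytes)} bytes of PNG data.")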
process_page(page, client, annotation_font_path, preprocessor=None)

Processes a single PDF page, extracting text, word locations, and annotated images.

Parameters:
    page (fitz.Page): The PDF page object. (required)
    client (vision.ImageAnnotatorClient): Google Vision API client for text detection. (required)
    annotation_font_path (str): Path to the font file for annotations. (required)
    preprocessor (Callable[[Image.Image, int], Image.Image], optional): Preprocessing function applied to the page image. Default is None.

Returns:
    Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
        - Full page text (str)
        - Word locations (List of vision.EntityAnnotation)
        - Annotated image (Pillow Image object)
        - Original unprocessed image (Pillow Image object)
        - Page dimensions (dict)

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def process_page(
    page: fitz.Page,
    client: vision.ImageAnnotatorClient,
    annotation_font_path: str,
    preprocessor: Callable[[Image.Image, int], Image.Image] = None,
) -> Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
    """
    Processes a single PDF page, extracting text, word locations, and annotated images.

    Parameters:
        page (fitz.Page): The PDF page object.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        annotation_font_path (str): Path to the font file for annotations.
        preprocessor (Callable[[Image.Image, int], Image.Image], optional): Preprocessing function for the image.

    Returns:
        Tuple[str, List[vision.EntityAnnotation], Image.Image, Image.Image, dict]:
            - Full page text (str)
            - Word locations (List of vision.EntityAnnotation)
            - Annotated image (Pillow Image object)
            - Original unprocessed image (Pillow Image object)
            - Page dimensions (dict)
    """
    # Extract the original image from the PDF page
    original_image = extract_image_from_page(page)

    # Make a copy of the original image for processing
    processed_image = original_image.copy()

    # Apply the preprocessing function (if provided)
    if preprocessor:
        # print("preprocessing...") # debug
        processed_image = preprocessor(processed_image, page.number)
        # processed_image.show() # debug

    # Annotate the processed image using the Vision API
    response = process_single_image(processed_image, client)

    if response:
        text_annotations = response.text_annotations
        # Extract full text and word locations
        full_page_text = text_annotations[0].description if text_annotations else ""
        word_locations = text_annotations[1:] if len(text_annotations) > 1 else []
    else:
        # return empty data
        full_page_text = ""
        word_locations = [EntityAnnotation()]
        text_annotations = [
            EntityAnnotation()
        ]  # create empty data structures to allow storing to proceed.

    # Create an annotated image with bounding boxes and labels
    annotated_image = annotate_image_with_text(
        processed_image, text_annotations, annotation_font_path
    )

    # Get page dimensions (from the original PDF page, not the image)
    page_dimensions = get_page_dimensions(page)

    return (
        full_page_text,
        word_locations,
        annotated_image,
        original_image,
        page_dimensions,
    )
process_single_image(image, client, feature_type=DEFAULT_ANNOTATION_METHOD, language_hints=DEFAULT_ANNOTATION_LANGUAGE_HINTS)

Processes a single image with the Google Vision API and returns text annotations.

Parameters:
    image (Image.Image): The preprocessed Pillow image object. (required)
    client (vision.ImageAnnotatorClient): Google Vision API client for text detection. (required)
    feature_type (str): Type of text detection to use ('TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION'). Default is DEFAULT_ANNOTATION_METHOD.
    language_hints (List): Language hints for OCR. Default is DEFAULT_ANNOTATION_LANGUAGE_HINTS.

Returns:
    vision.AnnotateImageResponse: The full Vision API response; text annotations are available via response.text_annotations.

Raises:
    ValueError: If an invalid feature type is supplied.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def process_single_image(
    image: Image.Image,
    client: vision.ImageAnnotatorClient,
    feature_type: str = DEFAULT_ANNOTATION_METHOD,
    language_hints: List = DEFAULT_ANNOTATION_LANGUAGE_HINTS,
) -> vision.AnnotateImageResponse:
    """
    Processes a single image with the Google Vision API and returns text annotations.

    Parameters:
        image (Image.Image): The preprocessed Pillow image object.
        client (vision.ImageAnnotatorClient): Google Vision API client for text detection.
        feature_type (str): Type of text detection to use ('TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION').
        language_hints (List): Language hints for OCR.

    Returns:
        vision.AnnotateImageResponse: The full Vision API response; text
        annotations are available via `response.text_annotations`.

    Raises:
        ValueError: If an invalid feature type is supplied.
    """
    # Convert the Pillow image to bytes
    image_bytes = pil_to_bytes(image, format="PNG")

    # Map feature type
    feature_map = {
        "TEXT_DETECTION": vision.Feature.Type.TEXT_DETECTION,
        "DOCUMENT_TEXT_DETECTION": vision.Feature.Type.DOCUMENT_TEXT_DETECTION,
    }
    if feature_type not in feature_map:
        raise ValueError(
            f"Invalid feature type '{feature_type}'. Use 'TEXT_DETECTION' or 'DOCUMENT_TEXT_DETECTION'."
        )

    # Prepare Vision API request
    vision_image = vision.Image(content=image_bytes)
    features = [vision.Feature(type=feature_map[feature_type])]
    image_context = vision.ImageContext(language_hints=language_hints)

    # Make the API call
    response = client.annotate_image(
        {"image": vision_image, "features": features, "image_context": image_context}
    )

    return response
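A minimal sketch of a direct OCR call on one page image. The image path is a placeholder; since the function returns the full API response, the recognized text is read from response.text_annotations.

from PIL import Image

client = start_image_annotator_client()
page_image = Image.open("/path/to/page_scan.png")  # placeholder path
response = process_single_image(page_image, client, feature_type="DOCUMENT_TEXT_DETECTION")
if response.text_annotations:
    print(response.text_annotations[0].description)  # full-page text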
save_processed_pdf_data(output_dir, journal_name, text_pages, word_locations, annotated_images, unannotated_images)

Saves processed PDF data to files for later reloading.

Parameters:
    output_dir (Path): Directory to save the data (as a Path object). (required)
    journal_name (str): Base name for the output subdirectory (usually the PDF name without extension). (required)
    text_pages (List[str]): Extracted full-page text. (required)
    word_locations (List[List[EntityAnnotation]]): Word locations and annotations from the Vision API. (required)
    annotated_images (List[Image.Image]): Annotated images with bounding boxes. (required)
    unannotated_images (List[Image.Image]): Raw unannotated images. (required)

Returns:
    None

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def save_processed_pdf_data(
    output_dir: Path,
    journal_name: str,
    text_pages: List[str],
    word_locations: List[List[EntityAnnotation]],
    annotated_images: List[Image.Image],
    unannotated_images: List[Image.Image],
):
    """
    Saves processed PDF data to files for later reloading.

    Parameters:
        output_dir (Path): Directory to save the data (as a Path object).
        journal_name (str): Base name for the output subdirectory (usually the PDF name without extension).
        text_pages (List[str]): Extracted full-page text.
        word_locations (List[List[EntityAnnotation]]): Word locations and annotations from Vision API.
        annotated_images (List[PIL.Image.Image]): Annotated images with bounding boxes.
        unannotated_images (List[PIL.Image.Image]): Raw unannotated images.

    Returns:
        None
    """
    # Create output directories
    base_path = output_dir / journal_name / "ocr_data"
    images_dir = base_path / "images"

    base_path.mkdir(parents=True, exist_ok=True)
    images_dir.mkdir(parents=True, exist_ok=True)

    # Save text data
    text_pages_file = base_path / "text_pages.json"
    with text_pages_file.open("w", encoding="utf-8") as f:
        json.dump(text_pages, f, indent=4, ensure_ascii=False)

    # Save word locations as JSON
    word_locations_file = base_path / "word_locations.json"
    serialized_word_locations = serialize_entity_annotations_to_json(word_locations)
    with word_locations_file.open("w", encoding="utf-8") as f:
        f.write(serialized_word_locations)

    # Save images
    for i, annotated_image in enumerate(annotated_images):
        annotated_file = images_dir / f"annotated_page_{i + 1}.png"
        annotated_image.save(annotated_file)
    for i, unannotated_image in enumerate(unannotated_images):
        unannotated_file = images_dir / f"unannotated_page_{i + 1}.png"
        unannotated_image.save(unannotated_file)

    # Save metadata
    metadata = {
        "source_pdf": journal_name,
        "page_count": len(text_pages),
        "images_directory": str(
            images_dir
        ),  # Convert Path to string for JSON serialization
        "files": {
            "text_pages": "text_pages.json",
            "word_locations": "word_locations.json",
        },
    }
    metadata_file = base_path / "metadata.json"
    with metadata_file.open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4)

    print(f"Processed data saved in: {base_path}")
serialize_entity_annotations_to_json(annotations)

Serializes a nested list of EntityAnnotation objects into a JSON-compatible format using Base64 encoding.

Parameters:
    annotations (List[List[EntityAnnotation]]): The nested list of EntityAnnotation objects. (required)

Returns:
    str: The serialized data in JSON format as a string.

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def serialize_entity_annotations_to_json(
    annotations: List[List[EntityAnnotation]],
) -> str:
    """
    Serializes a nested list of EntityAnnotation objects into a JSON-compatible format using Base64 encoding.

    Parameters:
        annotations (List[List[EntityAnnotation]]): The nested list of EntityAnnotation objects.

    Returns:
        str: The serialized data in JSON format as a string.
    """
    serialized_data = []
    for page_annotations in annotations:
        serialized_page = [
            base64.b64encode(annotation.SerializeToString()).decode("utf-8")
            for annotation in page_annotations
        ]
        serialized_data.append(serialized_page)

    # Convert to a JSON string
    return json.dumps(serialized_data, indent=4)
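A round-trip sketch with its counterpart deserialize_entity_annotations_from_json, using word locations produced elsewhere in this module (for example by build_processed_pdf).

from pathlib import Path

blob = serialize_entity_annotations_to_json(word_locations_list)  # e.g. from build_processed_pdf
Path("word_locations.json").write_text(blob, encoding="utf-8")

restored = deserialize_entity_annotations_from_json(
    Path("word_locations.json").read_text(encoding="utf-8")
)
assert len(restored) == len(word_locations_list)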
start_image_annotator_client(credentials_file=None, api_endpoint='vision.googleapis.com', timeout=(10, 30), enable_logging=False)

Starts and returns a Google Vision API ImageAnnotatorClient with optional configuration.

Parameters:
    credentials_file (str, optional): Path to the credentials JSON file. If None, uses the default environment variable. Default is None.
    api_endpoint (str): Custom API endpoint for the Vision API. Default is 'vision.googleapis.com'.
    timeout (Tuple[int, int]): Connection and read timeouts in seconds. Default is (10, 30).
    enable_logging (bool): Enable detailed logging for debugging. Default is False.

Returns:
    vision.ImageAnnotatorClient: Configured Vision API client.

Raises:
    FileNotFoundError: If the specified credentials file is not found.
    Exception: For unexpected errors during client setup.

Example:
    client = start_image_annotator_client(
        credentials_file="/path/to/credentials.json",
        api_endpoint="vision.googleapis.com",
        timeout=(10, 30),
        enable_logging=True,
    )
    print("Google Vision API client initialized.")

Source code in src/tnh_scholar/ocr_processing/ocr_processing.py
def start_image_annotator_client(
    credentials_file: str = None,
    api_endpoint: str = "vision.googleapis.com",
    timeout: Tuple[int, int] = (10, 30),
    enable_logging: bool = False,
) -> vision.ImageAnnotatorClient:
    """
    Starts and returns a Google Vision API ImageAnnotatorClient with optional configuration.

    Parameters:
        credentials_file (str): Path to the credentials JSON file. If None, uses the default environment variable.
        api_endpoint (str): Custom API endpoint for the Vision API. Default is the global endpoint.
        timeout (Tuple[int, int]): Connection and read timeouts in seconds. Default is (10, 30).
        enable_logging (bool): Enable detailed logging for debugging. Default is False.

    Returns:
        vision.ImageAnnotatorClient: Configured Vision API client.

    Raises:
        FileNotFoundError: If the specified credentials file is not found.
        Exception: For unexpected errors during client setup.

    Example:
        >>> client = start_image_annotator_client(
        >>>     credentials_file="/path/to/credentials.json",
        >>>     api_endpoint="vision.googleapis.com",
        >>>     timeout=(10, 30),
        >>>     enable_logging=True
        >>> )
        >>> print("Google Vision API client initialized.")
    """
    try:
        # Set up credentials
        if credentials_file:
            if not os.path.exists(credentials_file):
                raise FileNotFoundError(
                    f"Credentials file '{credentials_file}' not found."
                )
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_file

        # Configure client options
        client_options = {"api_endpoint": api_endpoint}
        client = vision.ImageAnnotatorClient(client_options=client_options)

        # Optionally enable logging
        if enable_logging:
            print(f"Vision API Client started with endpoint: {api_endpoint}")
            print(f"Timeout settings: Connect={timeout[0]}s, Read={timeout[1]}s")

        return client

    except Exception as e:
        raise Exception(f"Failed to initialize ImageAnnotatorClient: {e}")

openai_interface

openai_interface

DEBUG_DISPLAY_BUFFER = 1000 module-attribute
DEFAULT_MAX_BATCH_RETRY = 60 module-attribute
DEFAULT_MODEL_SETTINGS = {'gpt-4o': {'max_tokens': 16000, 'context_limit': 128000, 'temperature': 1.0}, 'gpt-3.5-turbo': {'max_tokens': 4096, 'context_limit': 16384, 'temperature': 1.0}, 'gpt-4o-mini': {'max_tokens': 16000, 'context_limit': 128000, 'temperature': 1.0}} module-attribute
MAX_BATCH_LIST = 30 module-attribute
OPEN_AI_DEFAULT_MODEL = 'gpt-4o' module-attribute
logger = get_child_logger(__name__) module-attribute
open_ai_encoding = tiktoken.encoding_for_model(OPEN_AI_DEFAULT_MODEL) module-attribute
open_ai_model_settings = DEFAULT_MODEL_SETTINGS module-attribute
ClientNotInitializedError

Bases: Exception

Exception raised when the OpenAI client is not initialized.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
class ClientNotInitializedError(Exception):
    """Exception raised when the OpenAI client is not initialized."""

    pass
OpenAIClient

Singleton class for managing the OpenAI client.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
class OpenAIClient:
    """Singleton class for managing the OpenAI client."""

    _instance = None

    def __init__(self, api_key: str):
        """Initialize the OpenAI client."""
        self.client = OpenAI(api_key=api_key)

    @classmethod
    def get_instance(cls):
        """
        Get or initialize the OpenAI client.

        Returns:
            OpenAI: The singleton OpenAI client instance.
        """
        if cls._instance is None:
            # Load the .env file
            load_dotenv()

            if api_key := os.getenv("OPENAI_API_KEY"):
                # Initialize the singleton instance
                cls._instance = cls(api_key)
            else:
                raise ValueError(
                    "API key not found. Set it in the .env file with the key 'OPENAI_API_KEY'."
                )

        return cls._instance.client
client = OpenAI(api_key=api_key) instance-attribute
__init__(api_key)

Initialize the OpenAI client.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def __init__(self, api_key: str):
    """Initialize the OpenAI client."""
    self.client = OpenAI(api_key=api_key)
get_instance() classmethod

Get or initialize the OpenAI client.

Returns:
    OpenAI: The singleton OpenAI client instance.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
@classmethod
def get_instance(cls):
    """
    Get or initialize the OpenAI client.

    Returns:
        OpenAI: The singleton OpenAI client instance.
    """
    if cls._instance is None:
        # Load the .env file
        load_dotenv()

        if api_key := os.getenv("OPENAI_API_KEY"):
            # Initialize the singleton instance
            cls._instance = cls(api_key)
        else:
            raise ValueError(
                "API key not found. Set it in the .env file with the key 'OPENAI_API_KEY'."
            )

    return cls._instance.client
create_jsonl_file_for_batch(messages, output_file_path=None, max_token_list=None, model=OPEN_AI_DEFAULT_MODEL, tools=None, json_mode=False)

Creates a JSONL file for batch processing, with each request using the same system message, user messages, and optional function schema for function calling.

Parameters:
    messages (List): List of message blocks to be sent for completion. (required)
    output_file_path (Path, optional): The path where the .jsonl file will be saved. Default is None (a dated file name is generated).
    max_token_list (List[int], optional): Per-request max_tokens values. Default is None (the model maximum is used for every request).
    model (str): The model to use. Default is OPEN_AI_DEFAULT_MODEL.
    tools (list, optional): List of tool/function schemas to enable function calling. Default is None.
    json_mode (bool, optional): If True, request JSON-object responses. Default is False.

Returns:
    Path: The path to the generated .jsonl file.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def create_jsonl_file_for_batch(
    messages: List[str],
    output_file_path: Optional[Path] = None,
    max_token_list: Optional[List[int]] = None,
    model: str = OPEN_AI_DEFAULT_MODEL,
    tools=None,
    json_mode: Optional[bool] = False,
):
    """
    Creates a JSONL file for batch processing, with each request using the same system message, user messages,
    and optional function schema for function calling.

    Args:
        messages: List of message objects to be sent for completion.
        output_file_path (Path, optional): The path where the .jsonl file will be saved.
        max_token_list (List[int], optional): Per-request max_tokens values; defaults to the model maximum for each request.
        model (str): The model to use (default is set globally).
        tools (list, optional): List of tool/function schemas to enable function calling.
        json_mode (bool, optional): If True, request JSON-object responses.

    Returns:
        Path: The path to the generated .jsonl file.
    """
    global open_ai_model_settings

    total_tokens = 0

    if not max_token_list:
        max_tokens = open_ai_model_settings[model]["max_tokens"]
        message_count = len(messages)
        max_token_list = [max_tokens] * message_count

    temperature = open_ai_model_settings[model]["temperature"]
    total_tokens = sum(max_token_list)

    if output_file_path is None:
        date_str = datetime.now().strftime("%m%d%Y")
        output_file_path = Path(f"batch_requests_{date_str}.jsonl")

    # Ensure the directory for the output file exists
    output_dir = Path(output_file_path).parent
    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)

    requests = []
    for i, message in enumerate(messages):

        # get max_tokens
        max_tokens = max_token_list[i]

        request_obj = {
            "custom_id": f"request-{i+1}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "messages": message,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        }
        if json_mode:
            request_obj["body"]["response_format"] = {"type": "json_object"}
        if tools:
            request_obj["body"]["tools"] = tools

        if i == 0:  # log first iteration only.
            _log_batch_creation_info(output_file_path, request_obj, total_tokens)

        requests.append(request_obj)

    # Write requests to JSONL file
    with open(output_file_path, "w") as f:
        for request in requests:
            json.dump(request, f)
            f.write("\n")

    logger.info(f"JSONL file created at: {output_file_path}")
    return output_file_path
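A sketch of preparing a batch request file with generate_messages (documented below); the prompts and file name are illustrative only.

from pathlib import Path

system_message = "You translate Dharma talk excerpts into English."
prompts = ["Bonjour, chers amis.", "Thưa đại chúng."]  # illustrative inputs
messages = generate_messages(system_message, lambda text: f"Translate:\n{text}", prompts)

jsonl_path = create_jsonl_file_for_batch(messages, Path("batch_requests_demo.jsonl"))
print(f"Wrote batch request file: {jsonl_path}")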
delete_api_files(cutoff_date)

Delete all files on OpenAI's storage older than a given date at midnight.

Parameters:
    cutoff_date (datetime): The cutoff date. Files older than this date will be deleted.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def delete_api_files(cutoff_date: datetime):
    """
    Delete all files on OpenAI's storage older than a given date at midnight.

    Parameters:
    - cutoff_date (datetime): The cutoff date. Files older than this date will be deleted.
    """
    # Set the OpenAI API key
    client = get_api_client()

    # Get a list of all files
    files = client.files.list()

    for file in files.data:
        # Parse the file creation timestamp
        file_created_at = datetime.fromtimestamp(file.created_at)
        # Check if the file is older than the cutoff date
        if file_created_at < cutoff_date:
            try:
                # Delete the file
                client.files.delete(file.id)
                print(f"Deleted file: {file.id} (created on {file_created_at})")
            except Exception as e:
                logger.error(f"Failed to delete file {file.id}: {e}", exc_info=True)
generate_messages(system_message, user_message_wrapper, data_list_to_process, log_system_message=True)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def generate_messages(
    system_message: str,
    user_message_wrapper: callable,
    data_list_to_process: List,
    log_system_message=True,
):
    messages = []
    for data_element in data_list_to_process:
        message_block = [
            {"role": "system", "content": system_message},
            {
                "role": "user",
                "content": user_message_wrapper(data_element),
            },
        ]
        messages.append(message_block)
    return messages
get_active_batches()

Retrieve the list of active batches using the OpenAI API.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_active_batches() -> List[Dict]:
    """
    Retrieve the list of active batches using the OpenAI API.
    """
    client = get_api_client()

    try:
        batches = client.batches.list(limit=MAX_BATCH_LIST)
        batch_list = []
        for batch in batches:
            if batch.status == "in_progress":
                batch_info = {
                    "id": batch.id,
                    "status": batch.status,
                    "created_at": batch.created_at,
                    # Add other relevant attributes as needed
                }
                batch_list.append(batch_info)
        return batch_list
    except Exception as e:
        logger.error(f"Error fetching active batches: {e}")
        return []
get_all_batch_info()

Retrieve the list of batches up to MAX_BATCH_LIST using the OpenAI API.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_all_batch_info():
    """
    Retrieve the list of batches up to MAX_BATCH_LIST using the OpenAI API.
    """
    client = get_api_client()

    try:
        batches = client.batches.list(limit=MAX_BATCH_LIST)
        batch_list = []
        for batch in batches:
            batch_info = {
                "id": batch.id,
                "status": batch.status,
                "created_at": batch.created_at,
                "output_file_id": batch.output_file_id,
                "metadata": batch.metadata,
                # Add other relevant attributes as needed
            }
            batch_list.append(batch_info)
        return batch_list
    except Exception as e:
        logger.error(f"Error fetching active batches: {e}", exc_info=True)
        return []
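A small sketch listing recent batches (up to MAX_BATCH_LIST) and their status:

for info in get_all_batch_info():
    print(info["id"], info["status"], info.get("output_file_id"))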
get_api_client()
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_api_client():
    return OpenAIClient.get_instance()
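A minimal sketch: the singleton client is created on first use from OPENAI_API_KEY (read from the environment or a .env file) and then behaves like any standard OpenAI client.

client = get_api_client()      # initializes the singleton on first call
models = client.models.list()  # any standard OpenAI client call works from here
print(f"{len(models.data)} models visible to this key")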
get_batch_response(batch_id)

Retrieves the status of a batch job and returns the result if completed. Parses the JSON result file, collects the output messages, and returns them as a Python list.

Args:
    batch_id (str): The batch ID to retrieve status and results for.

Returns:
    If completed: A list containing the message content for each response of the batch process.
    If not completed: A string with the batch status.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_batch_response(batch_id: str) -> List[str]:
    """
    Retrieves the status of a batch job and returns the result if completed.
    Parses the JSON result file, collects the output messages,
    and returns them as a Python list.

    Args:
    - batch_id : The batch_id string to retrieve status and results for.

    Returns:
    - If completed: A list containing the message content for each response of the batch process.
    - If not completed: A string with the batch status.
    """
    client = get_api_client()

    # Check the batch status
    batch_status = client.batches.retrieve(batch_id)
    if batch_status.status != "completed":
        logger.info(f"Batch status for {batch_id}: {batch_status.status}")
        return batch_status.status

    # Retrieve the output file contents
    file_id = batch_status.output_file_id
    file_response = client.files.content(file_id)

    # Parse the JSON lines in the output file
    results = []
    for line in file_response.text.splitlines():
        data = json.loads(line)  # Parse each line as JSON
        if response_body := data.get("response", {}).get("body", {}):
            content = response_body["choices"][0]["message"]["content"]
            results.append(content)

    return results
get_batch_status(batch_id)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_batch_status(batch_id):
    client = get_api_client()

    batch = client.batches.retrieve(batch_id)
    return batch.status
get_completed_batches()

Retrieve the list of completed batches using the OpenAI API.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_completed_batches() -> List[Dict]:
    """
    Retrieve the list of completed batches using the OpenAI API.
    """
    client = get_api_client()

    try:
        batches = client.batches.list(limit=MAX_BATCH_LIST)
        batch_list = []
        for batch in batches:
            if batch.status == "completed":
                batch_info = {
                    "id": batch.id,
                    "status": batch.status,
                    "created_at": batch.created_at,
                    "output_file_id": batch.output_file_id,
                    "metadata": batch.metadata,
                    # Add other relevant attributes as needed
                }
                batch_list.append(batch_info)
        return batch_list
    except Exception as e:
        logger.error(f"Error fetching active batches: {e}", exc_info=True)
        return []
get_completion_content(chat_completion)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_completion_content(chat_completion):
    return chat_completion.choices[0].message.content
get_completion_object(chat_completion)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_completion_object(chat_completion):
    return chat_completion.choices[0].message.parsed
get_last_batch_response(n=0)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_last_batch_response(n: int = 0):
    assert n < MAX_BATCH_LIST
    completed = get_completed_batches()
    return get_batch_response(completed[n]["id"])
get_model_settings(model, parameter)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def get_model_settings(model, parameter):
    return open_ai_model_settings[model][parameter]
poll_batch_for_response(batch_id, interval=10, timeout=3600, backoff_factor=1.3, max_interval=600)

Poll the batch status until it completes, fails, or expires.

Parameters:
    batch_id (str): The ID of the batch to poll. (required)
    interval (int): Initial time (in seconds) to wait between polls. Default is 10.
    timeout (int): Maximum duration (in seconds) to poll before timing out. Default is 3600 (1 hour).
    backoff_factor (float): Factor by which the interval increases after each poll. Default is 1.3.
    max_interval (int): Maximum polling interval in seconds. Default is 600.

Returns:
    list: The batch response if successful.
    bool: False if the batch fails, times out, or expires.

Raises:
    RuntimeError: If the batch ID is not found or if an unexpected error occurs.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def poll_batch_for_response(
    batch_id: str,
    interval: int = 10,
    timeout: int = 3600,
    backoff_factor: float = 1.3,
    max_interval: int = 600,
) -> bool | list:
    """
    Poll the batch status until it completes, fails, or expires.

    Args:
        batch_id (str): The ID of the batch to poll.
        interval (int): Initial time (in seconds) to wait between polls. Default is 10 seconds.
        timeout (int): Maximum duration (in seconds) to poll before timing out. Default is 3600 seconds (1 hour).
        backoff_factor (float): Factor by which the interval increases after each poll.
        max_interval (int): Maximum polling interval in seconds.

    Returns:
        list: The batch response if successful.
        bool: Returns False if the batch fails, times out, or expires.

    Raises:
        RuntimeError: If the batch ID is not found or if an unexpected error occurs.
    """
    start_time = time.time()
    logger.info(f"Polling batch status for batch ID {batch_id} ...")

    attempts = 0
    while True:
        try:
            time.sleep(interval)
            elapsed_time = time.time() - start_time

            # Check for timeout
            if elapsed_time > timeout:
                logger.error(
                    f"Polling timed out after {timeout} seconds for batch ID {batch_id}."
                )
                return False

            # Get batch status
            batch_status = get_batch_status(batch_id)
            logger.debug(f"Batch ID {batch_id} status: {batch_status}")

            if not batch_status:
                raise RuntimeError(
                    f"Batch ID {batch_id} not found or invalid response from `get_batch_status`."
                )

            # Handle completed batch
            if batch_status == "completed":
                logger.info(
                    f"Batch processing for ID {batch_id} completed successfully."
                )
                try:
                    return get_batch_response(batch_id)
                except Exception as e:
                    logger.error(
                        f"Error retrieving response for batch ID {batch_id}: {e}",
                        exc_info=True,
                    )
                    raise RuntimeError(
                        f"Failed to retrieve response for batch ID {batch_id}."
                    ) from e

            # Handle failed batch
            elif batch_status == "failed":
                logger.error(f"Batch processing for ID {batch_id} failed.")
                return False

            # Log ongoing status and adjust interval
            logger.info(
                f"Batch status: {batch_status}. Retrying in {interval} seconds..."
            )
            attempts += 1
            interval = min(floor(interval * backoff_factor), max_interval)

        except Exception as e:
            logger.error(
                f"Unexpected error while polling batch ID {batch_id}: {e}",
                exc_info=True,
            )
            raise RuntimeError(f"Error during polling for batch ID {batch_id}.") from e
run_immediate_chat_process(messages, max_tokens=0, response_format=None, model=OPEN_AI_DEFAULT_MODEL)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def run_immediate_chat_process(
    messages, max_tokens: int = 0, response_format=None, model=OPEN_AI_DEFAULT_MODEL
):
    client = get_api_client()

    max_model_tokens = open_ai_model_settings[model]["max_tokens"]
    if max_tokens == 0:
        max_tokens = max_model_tokens

    if max_tokens > max_model_tokens:
        logger.warning(
            "Maximum token request exceeded: {max_tokens} for model: {model}"
        )
        logger.warning(f"Setting max_tokens to model maximum: {max_model_tokens}")
        max_tokens = max_model_tokens

    try:
        return (
            client.beta.chat.completions.parse(
                messages=messages,
                model=model,
                response_format=response_format,
                max_completion_tokens=max_tokens,
            )
            if response_format
            else client.chat.completions.create(
                messages=messages,
                model=model,
                max_completion_tokens=max_tokens,
            )
        )
    except Exception as e:
        logger.error(f"Error running immediate chat: {e}", exc_info=True)
        return None
run_immediate_completion_simple(system_message, user_message, model=None, max_tokens=0, response_format=None)

Runs a single chat completion with a system message and user message.

This function simplifies the process of running a single chat completion with the OpenAI API by handling model selection, token limits, and logging. It allows for specifying a response format and handles potential ValueError exceptions during the API call.

Parameters:
    system_message (str): The system message to guide the conversation. (required)
    user_message (str): The user's message as input for the chat completion. (required)
    model (str, optional): The OpenAI model to use. Defaults to None, which uses the default model.
    max_tokens (int, optional): The maximum number of tokens for the completion. Defaults to 0, which uses the model's maximum.
    response_format (dict, optional): The desired response format. Defaults to None.

Returns:
    OpenAIObject | None: The chat completion response if successful, or None if a ValueError occurs.

Raises:
    ValueError: If max_tokens exceeds the model's maximum token limit.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def run_immediate_completion_simple(
    system_message: str,
    user_message: str,
    model=None,
    max_tokens: int = 0,
    response_format=None,
):
    """Runs a single chat completion with a system message and user message.

    This function simplifies the process of running a single chat completion with the OpenAI API by handling
    model selection, token limits, and logging. It allows for specifying a response format and handles potential
    `ValueError` exceptions during the API call.

    Args:
        system_message (str): The system message to guide the conversation.
        user_message (str): The user's message as input for the chat completion.
        model (str, optional): The OpenAI model to use. Defaults to None, which uses the default model.
        max_tokens (int, optional): The maximum number of tokens for the completion. Defaults to 0, which uses the model's maximum.
        response_format (dict, optional): The desired response format. Defaults to None.

    Returns:
        OpenAIObject | None: The chat completion response if successful, or None if a `ValueError` occurs.

    Raises:
        ValueError: if max_tokens exceeds the model's maximum token limit.
    """

    client = get_api_client()

    if not model:
        model = OPEN_AI_DEFAULT_MODEL

    max_model_tokens = open_ai_model_settings[model]["max_tokens"]
    if max_tokens == 0:
        max_tokens = max_model_tokens

    if max_tokens > max_model_tokens:
        logger.warning(
            "Maximum token request exceeded: {max_tokens} for model: {model}"
        )
        logger.warning(f"Setting max_tokens to model maximum: {max_model_tokens}")
        max_tokens = max_model_tokens

    logger.debug(f"User message content:\n{user_message[:DEBUG_DISPLAY_BUFFER]} ...")
    message_block = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    try:
        logger.debug(
            f"Starting chat completion with response_format={response_format} and max_tokens={max_tokens}..."
        )

        return (
            client.beta.chat.completions.parse(
                messages=message_block,  # type: ignore
                model=model,
                response_format=response_format,
                max_completion_tokens=max_tokens,
            )
            if response_format
            else client.chat.completions.create(
                messages=message_block,  # type: ignore
                model=model,
                max_completion_tokens=max_tokens,
            )
        )
    except ValueError as e:
        logger.error(f"Value Error running immediate chat: {e}", exc_info=True)
        return None
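A one-shot completion sketch; get_completion_content (above) extracts the message text from the returned completion object.

completion = run_immediate_completion_simple(
    system_message="You are a concise translator.",
    user_message="Translate into English: 'Thở vào, tâm tĩnh lặng.'",
)
if completion is not None:
    print(get_completion_content(completion))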
run_single_batch(user_prompts, system_message, user_wrap_function=None, max_token_list=None, description='')

Generate a batch file for the OpenAI (OA) API and send it.

Parameters:
    user_prompts (List): User prompts to process, one request per prompt. (required)
    system_message (str): System message template for batch processing. (required)
    user_wrap_function (callable, optional): Function to wrap each user prompt for processing. Default is None (prompts are used as-is).
    max_token_list (List[int], optional): Per-request max_tokens values. Default is None.
    description (str, optional): Description used as batch metadata. Default is ''.

Returns:
    List[str]: The response content for each request in the batch.

Raises:
    Exception: If an error occurs during file processing.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def run_single_batch(
    user_prompts: List,
    system_message: str,
    user_wrap_function: callable = None,
    max_token_list: List[int] = None,
    description="",
) -> List[str]:
    """
    Generate a batch file for the OpenAI (OA) API and send it.

    Parameters:
        user_prompts (List): User prompts to process, one request per prompt.
        system_message (str): System message template for batch processing.
        user_wrap_function (callable, optional): Function to wrap each user prompt; defaults to the identity function.
        max_token_list (List[int], optional): Per-request max_tokens values.
        description (str, optional): Description used as batch metadata.

    Returns:
        List[str]: The response content for each request in the batch.

    Raises:
        Exception: If an error occurs during file processing.
    """

    if max_token_list is None:
        max_token_list = []
    try:
        if not user_wrap_function:
            user_wrap_function = lambda x: x

        # Generate messages for the pages
        batch_message_seq = generate_messages(
            system_message, user_wrap_function, user_prompts
        )

        batch_file = Path("./temp_batch_run.jsonl")

        # Save the batch file
        create_jsonl_file_for_batch(
            batch_message_seq, batch_file, max_token_list=max_token_list
        )
        # logger.info(f"Batch file created successfully: {output_file}")

    except Exception as e:
        logger.error(f"Error while creating immediate batch file {batch_file}: {e}")
        raise

    try:

        if not description:
            description = (
                f"Single batch process: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
            )
        response_list = start_batch_with_retries(batch_file, description=description)

    except Exception as e:
        logger.error(f"Failed to complete batch process: {e}", exc_info=True)
        raise

    return response_list
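A hedged usage sketch for run_single_batch (the import path is inferred from the source location above; the prompts and wrap function are illustrative):

from tnh_scholar.openai_interface.openai_interface import run_single_batch

pages = ["OCR text of page one ...", "OCR text of page two ..."]

responses = run_single_batch(
    user_prompts=pages,
    system_message="Clean the OCR text of this journal page, preserving meaning.",
    user_wrap_function=lambda page: f"Page text:\n{page}",
    description="journal cleaning demo",
)
for i, result in enumerate(responses, start=1):
    print(f"--- page {i} ---")
    print(result)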
run_transcription_speech(audio_file, model=OPEN_AI_DEFAULT_MODEL, response_format='verbose_json', prompt='', mode='transcribe')
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def run_transcription_speech(
    audio_file: Path,
    model: str = OPEN_AI_DEFAULT_MODEL,
    response_format="verbose_json",
    prompt="",
    mode: str = "transcribe",
):  # mode can be "transcribe" or "translate"

    client = get_api_client()

    with audio_file.open("rb") as file:
        if mode == "transcribe":
            transcript = client.audio.transcriptions.create(
                model=model, response_format=response_format, prompt=prompt, file=file
            )
        elif mode == "translate":
            transcript = client.audio.translations.create(
                model=model, response_format=response_format, prompt=prompt, file=file
            )
        else:
            logger.error(f"Invalid mode: {mode}, in speech transcription generation.")
            raise ValueError(f"'translate' or 'transcribe' expected, not {mode}.")

    return transcript
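A usage sketch for run_transcription_speech (import path assumed as above; "whisper-1" is an example model name, not necessarily the package default):

from pathlib import Path

from tnh_scholar.openai_interface.openai_interface import run_transcription_speech

transcript = run_transcription_speech(
    Path("dharma_talk_excerpt.mp3"),
    model="whisper-1",
    response_format="verbose_json",
    prompt="Plum Village Dharma talk",
    mode="transcribe",  # or "translate" to translate the speech into English
)
print(transcript)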
set_model_settings(model_settings_dict)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def set_model_settings(model_settings_dict):
    global open_ai_model_settings
    open_ai_model_settings = model_settings_dict
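Because the completion helper above reads open_ai_model_settings[model]["max_tokens"], the settings dict passed here is keyed by model name, with at least a "max_tokens" entry per model. A sketch (model names and limits are illustrative only):

from tnh_scholar.openai_interface.openai_interface import set_model_settings

set_model_settings(
    {
        "gpt-4o": {"max_tokens": 16384},
        "gpt-4o-mini": {"max_tokens": 16384},
    }
)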
start_batch(jsonl_file, description='')

Starts a batch process using OpenAI's client with an optional description and JSONL batch file.

Parameters:
    jsonl_file (Path): Path to the .jsonl batch file to be used as input. Must be a pathlib.Path object. (required)
    description (str): A description for metadata to label the batch job. If None, a default description is generated with the current date-time and file name. Default: ''

Returns:
    dict: A dictionary containing the batch object if successful, or an error message if failed.

Example:
    jsonl_file = Path("batch_requests.jsonl")
    start_batch(jsonl_file)

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def start_batch(jsonl_file: Path, description=""):
    """
    Starts a batch process using OpenAI's client with an optional description and JSONL batch file.

    Args:
        jsonl_file (Path): Path to the .jsonl batch file to be used as input. Must be a pathlib.Path object.
        description (str, optional): A description for metadata to label the batch job.
                                     If None, a default description is generated with the
                                     current date-time and file name.

    Returns:
        dict: A dictionary containing the batch object if successful, or an error message if failed.

    Example:
        jsonl_file = Path("batch_requests.jsonl")
        start_batch(jsonl_file)
    """
    client = get_api_client()

    if not isinstance(jsonl_file, Path):
        raise TypeError("The 'jsonl_file' argument must be a pathlib.Path object.")

    if not jsonl_file.exists():
        raise FileNotFoundError(f"The file {jsonl_file} does not exist.")

    basename = jsonl_file.stem

    # Generate description:
    current_time = datetime.now().astimezone().strftime("%m-%d-%Y %H:%M:%S %Z")
    description = f"{current_time} | {jsonl_file.name} | {description}"

    try:
        # Attempt to create the input file for the batch process
        with jsonl_file.open("rb") as file:
            batch_input_file = client.files.create(file=file, purpose="batch")
        batch_input_file_id = batch_input_file.id
    except Exception as e:
        return {"error": f"File upload failed: {e}"}

    try:
        # Attempt to create the batch with specified input file and metadata description
        batch = client.batches.create(
            input_file_id=batch_input_file_id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={"description": description, "basename": basename},
        )

        # log the batch
        _log_batch_start_info(batch, description)
        return batch

    except Exception as e:
        return {"error": f"Batch creation failed: {e}"}
start_batch_with_retries(jsonl_file, description='', max_retries=DEFAULT_MAX_BATCH_RETRY, retry_delay=5, poll_interval=10, timeout=3600)

Starts a batch with retries and polls for its completion.

Parameters:
    jsonl_file (Path): Path to the JSONL file for batch input. (required)
    description (str): A description for the batch job (optional). Default: ''
    max_retries (int): Maximum number of retries to start and complete the batch. Default: DEFAULT_MAX_BATCH_RETRY
    retry_delay (int): Delay in seconds between retries. Default: 5
    poll_interval (int): Interval in seconds for polling batch status. Default: 10
    timeout (int): Timeout in seconds for polling. Default: 3600

Returns:
    list[str]: The batch responses if the batch completes successfully.

Raises:
    RuntimeError: If the batch fails after all retries or encounters an error.

Source code in src/tnh_scholar/openai_interface/openai_interface.py
def start_batch_with_retries(
    jsonl_file: Path,
    description: str = "",
    max_retries: int = DEFAULT_MAX_BATCH_RETRY,
    retry_delay: int = 5,
    poll_interval: int = 10,
    timeout: int = 3600,
) -> list[str]:
    """
    Starts a batch with retries and polls for its completion.

    Args:
        jsonl_file (Path): Path to the JSONL file for batch input.
        description (str): A description for the batch job (optional).
        max_retries (int): Maximum number of retries to start and complete the batch (default: DEFAULT_MAX_BATCH_RETRY).
        retry_delay (int): Delay in seconds between retries (default: 5).
        poll_interval (int): Interval in seconds for polling batch status (default: 10).
        timeout (int): Timeout in seconds for polling (default: 3600).

    Returns:
        list: The batch response if completed successfully.

    Raises:
        RuntimeError: If the batch fails after all retries or encounters an error.
    """
    for attempt in range(max_retries):
        try:
            # Start the batch
            batch = start_batch(jsonl_file, description=description)
            if not batch or "error" in batch:
                raise RuntimeError(
                    f"Failed to start batch: {batch.get('error', 'Unknown error')}"
                )

            batch_id = batch.id
            if not batch_id:
                raise RuntimeError("Batch started but no ID was returned.")

            logger.info(
                f"Batch started: attempt {attempt + 1}.",
                extra={"batch_id": batch_id, "description": description},
            )

            # Poll for batch completion
            response_list = poll_batch_for_response(
                batch_id, interval=poll_interval, timeout=timeout
            )

            # Check for a response
            if response_list:
                logger.info(
                    f"Batch completed successfully after {attempt + 1} attempts.",
                    extra={"batch_id": batch_id, "description": description},
                )
                break  # exit for loop

            else:  # No response means batch failed. Retry.
                logger.error(
                    f"Attempt {attempt + 1} failed. Retrying batch process in {retry_delay} seconds...",
                    extra={
                        "attempt": attempt + 1,
                        "max_retries": max_retries,
                        "description": description,
                    },
                )
                time.sleep(retry_delay)

        except Exception as e:
            logger.error(
                f"Batch start and polling failed on attempt {attempt + 1}: {e}",
                exc_info=True,
                extra={"attempt": attempt + 1, "description": description},
            )
            time.sleep(retry_delay)

    else:  # the loop completed without a successful result
        logger.error(
            f"Failed to complete batch after {max_retries} retries.",
            extra={"description": description},
        )
        raise RuntimeError(
            f"Error: Failed to complete batch after {max_retries} retries."
        )

    return response_list
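A usage sketch for start_batch_with_retries, which wraps start_batch with polling and retry logic and raises RuntimeError if the batch never completes (import path assumed as above; the retry and timeout values are illustrative):

from pathlib import Path

from tnh_scholar.openai_interface.openai_interface import start_batch_with_retries

try:
    responses = start_batch_with_retries(
        Path("batch_requests.jsonl"),
        description="section segmentation",
        max_retries=2,
        poll_interval=30,
        timeout=2 * 60 * 60,  # give up after two hours of polling
    )
except RuntimeError as err:
    print(f"Batch did not complete: {err}")
else:
    print(f"Received {len(responses)} responses")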
token_count(text)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def token_count(text):
    return len(open_ai_encoding.encode(text))
token_count_file(text_file)
Source code in src/tnh_scholar/openai_interface/openai_interface.py
def token_count_file(text_file: Path):
    text = get_text_from_file(text_file)
    return token_count(text)
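Both helpers count tokens with the module-level open_ai_encoding, which is convenient for sizing batch requests against model limits. A small sketch (import path assumed as above; file name is illustrative):

from pathlib import Path

from tnh_scholar.openai_interface.openai_interface import token_count, token_count_file

print(token_count("Breathing in, I calm my body."))  # token count of a string
print(token_count_file(Path("transcript.txt")))      # token count of a file's contents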

run_oa_batch_jobs

BATCH_JOB_PATH = Path('UNSET') module-attribute
CHECK_INTERVAL_SECONDS = 60 module-attribute
ENQUEUED_BATCH_TOKEN_LIMIT = 90000 module-attribute
enqueued_tokens = 0 module-attribute
sent_batches = {} module-attribute
calculate_enqueued_tokens(active_batches)

Calculate the total number of enqueued tokens from active batches.

Source code in src/tnh_scholar/openai_interface/run_oa_batch_jobs.py
def calculate_enqueued_tokens(active_batches: List[Dict]) -> int:
    """
    Calculate the total number of enqueued tokens from active batches.
    """
    total_tokens = 0
    for batch in active_batches:
        total_tokens += batch.get("input_tokens", 0)
    return total_tokens
download_batch_result(client, batch_id)

Download the result of a completed batch.

Source code in src/tnh_scholar/openai_interface/run_oa_batch_jobs.py
def download_batch_result(client, batch_id):
    """
    Download the result of a completed batch.
    """
    try:
        response = client.get(f"/v1/batches/{batch_id}")
        result = response.get("result", {})
        output_file = f"batch_results_{batch_id}.json"
        with open(output_file, "w") as file:
            json.dump(result, file, indent=4)
        print(f"Batch {batch_id} completed. Result saved to {output_file}.")
    except Exception as e:
        print(f"Error downloading result for batch {batch_id}: {e}")
get_active_batches(client)

Retrieve the list of active batches using the OpenAI API.

Source code in src/tnh_scholar/openai_interface/run_oa_batch_jobs.py
def get_active_batches(client) -> List[Dict]:
    """
    Retrieve the list of active batches using the OpenAI API.
    """
    try:
        response = client.get("/v1/batches")
        return response.get("data", [])
    except Exception as e:
        print(f"Error fetching active batches: {e}")
        return []
main()

Main function to manage and monitor batch jobs.

Source code in src/tnh_scholar/openai_interface/run_oa_batch_jobs.py
def main():
    """
    Main function to manage and monitor batch jobs.
    """
    global enqueued_tokens, sent_batches
    client = set_api_client()
    if not client:
        print("Failed to initialize API client. Exiting.")
        return

    batch_file_directory = "./journal_cleaning_batches"

    while True:
        # Poll for completed batches
        print("Polling for completed batches...")
        poll_batches(client)

        # Calculate remaining tokens
        remaining_tokens = ENQUEUED_BATCH_TOKEN_LIMIT - enqueued_tokens
        print(f"Remaining tokens: {remaining_tokens}")

        # Enqueue new batches if there's space
        print("Processing batch files...")
        process_batch_files(client, batch_file_directory, remaining_tokens)

        # Wait for the next polling cycle
        print(f"Waiting for {CHECK_INTERVAL_SECONDS} seconds before next check...")
        time.sleep(CHECK_INTERVAL_SECONDS)
poll_batches(client)

Poll for completed batches and update global enqueued_tokens.

Source code in src/tnh_scholar/openai_interface/run_oa_batch_jobs.py
def poll_batches(client):
    """
    Poll for completed batches and update global enqueued_tokens.
    """
    global enqueued_tokens, sent_batches
    completed_batches = []

    for batch_id, info in sent_batches.items():
        batch = info["batch"]
        try:
            response = client.get(f"/v1/batches/{batch_id}")
            status = response.get("status")
            if status == "completed":
                download_batch_result(client, batch_id)
                enqueued_tokens -= info["token_size"]
                completed_batches.append(batch_id)
            elif status == "failed":
                print(f"Batch {batch_id} failed. Removing from tracking.")
                enqueued_tokens -= info["token_size"]
                completed_batches.append(batch_id)
        except Exception as e:
            print(f"Error checking status for batch {batch_id}: {e}")

    # Remove completed batches from sent_batches
    for batch_id in completed_batches:
        del sent_batches[batch_id]
process_batch_files(client, batch_file_directory, remaining_tokens)

Process batch files in the batch job directory, enqueue new batches if space permits.

Source code in src/tnh_scholar/openai_interface/run_oa_batch_jobs.py
def process_batch_files(client, batch_file_directory, remaining_tokens):
    """
    Process batch files in the batch job directory, enqueue new batches if space permits.
    """
    global enqueued_tokens, sent_batches
    batch_info = []

    for path_obj in Path(batch_file_directory).iterdir():
        regex = re.compile(r"^.*\.jsonl$")
        if path_obj.is_file() and regex.search(path_obj.name):
            batch_file = Path(batch_file_directory) / path_obj.name
            print(f"Found batch file: {batch_file}")

            # Calculate the token count for this batch
            try:
                with open(batch_file, "r") as file:
                    data = file.read()
                    batch_tokens = token_count(data)
            except Exception as e:
                print(f"Failed to calculate token count for {batch_file}: {e}")
                continue

            # Enqueue batch if there's space
            if batch_tokens <= remaining_tokens:
                try:
                    batch = start_batch(client, batch_file)
                    sent_batches[batch["id"]] = {
                        "batch": batch,
                        "token_size": batch_tokens,
                    }
                    enqueued_tokens += batch_tokens
                    remaining_tokens -= batch_tokens
                    print(f"Batch enqueued: {batch['id']}")
                except Exception as e:
                    print(f"Failed to enqueue batch {batch_file}: {e}")
            else:
                print(f"Insufficient token space for {batch_file}. Skipping.")
    return batch_info
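The enqueue decision above is a simple token budget: a batch file is sent only if its token count fits under ENQUEUED_BATCH_TOKEN_LIMIT. A standalone sketch of that bookkeeping, independent of the OpenAI client (names and values here are illustrative):

ENQUEUED_BATCH_TOKEN_LIMIT = 90_000

def can_enqueue(batch_tokens: int, enqueued_tokens: int) -> bool:
    """Return True if a batch of batch_tokens fits under the global budget."""
    return enqueued_tokens + batch_tokens <= ENQUEUED_BATCH_TOKEN_LIMIT

enqueued = 0
for tokens in (40_000, 35_000, 30_000):
    if can_enqueue(tokens, enqueued):
        enqueued += tokens
        print(f"enqueued batch of {tokens} tokens, total now {enqueued}")
    else:
        print(f"skipping batch of {tokens} tokens: budget of {ENQUEUED_BATCH_TOKEN_LIMIT} exceeded")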

text_processing

__all__ = ['bracket_lines', 'unbracket_lines', 'lines_from_bracketed_text', 'NumberedText', 'normalize_newlines', 'clean_text'] module-attribute

NumberedText

Represents a text document with numbered lines for easy reference and manipulation.

Provides utilities for working with line-numbered text including reading, writing, accessing lines by number, and iterating over numbered lines.

Attributes:
    lines (List[str]): List of text lines
    start (int): Starting line number (default: 1)
    separator (str): Separator between line number and content (default: ": ")

Examples:

>>> text = "First line\nSecond line\n\nFourth line"
>>> doc = NumberedText(text)
>>> print(doc)
1: First line
2: Second line
3:
4: Fourth line
>>> print(doc.get_line(2))
Second line
>>> for num, line in doc:
...     print(f"Line {num}: {len(line)} chars")
Source code in src/tnh_scholar/text_processing/numbered_text.py
class NumberedText:
    """
    Represents a text document with numbered lines for easy reference and manipulation.

    Provides utilities for working with line-numbered text including reading,
    writing, accessing lines by number, and iterating over numbered lines.

    Attributes:
        lines (List[str]): List of text lines
        start (int): Starting line number (default: 1)
        separator (str): Separator between line number and content (default: ": ")

    Examples:
        >>> text = "First line\\nSecond line\\n\\nFourth line"
        >>> doc = NumberedText(text)
        >>> print(doc)
        1: First line
        2: Second line
        3:
        4: Fourth line

        >>> print(doc.get_line(2))
        Second line

        >>> for num, line in doc:
        ...     print(f"Line {num}: {len(line)} chars")
    """

    @dataclass
    class LineSegment:
        """
        Represents a segment of lines with start and end indices in 1-based indexing.

        The segment follows Python range conventions where start is inclusive and
        end is exclusive. However, indexing is 1-based to match NumberedText.

        Attributes:
            start: Starting line number (inclusive, 1-based)
            end: Ending line number (exclusive, 1-based)
        """

        start: int
        end: int

        def __iter__(self):
            """Allow unpacking into start, end pairs."""
            yield self.start
            yield self.end

    class SegmentIterator:
        """
        Iterator for generating line segments of specified size.

        Produces segments of lines with start/end indices following 1-based indexing.
        The final segment may be smaller than the specified segment size.

        Attributes:
            total_lines: Total number of lines in text
            segment_size: Number of lines per segment
            start_line: Starting line number (1-based)
            min_segment_size: Minimum size for the final segment
        """

        def __init__(
            self,
            total_lines: int,
            segment_size: int,
            start_line: int = 1,
            min_segment_size: Optional[int] = None,
        ):
            """
            Initialize the segment iterator.

            Args:
                total_lines: Total number of lines to iterate over
                segment_size: Desired size of each segment
                start_line: First line number (default: 1)
                min_segment_size: Minimum size for final segment (default: None)
                    If specified, the last segment will be merged with the previous one
                    if it would be smaller than this size.

            Raises:
                ValueError: If segment_size < 1 or total_lines < 1
                ValueError: If start_line < 1 (must use 1-based indexing)
                ValueError: If min_segment_size >= segment_size
            """
            if segment_size < 1:
                raise ValueError("Segment size must be at least 1")
            if total_lines < 1:
                raise ValueError("Total lines must be at least 1")
            if start_line < 1:
                raise ValueError("Start line must be at least 1 (1-based indexing)")
            if min_segment_size is not None and min_segment_size >= segment_size:
                raise ValueError("Minimum segment size must be less than segment size")

            self.total_lines = total_lines
            self.segment_size = segment_size
            self.start_line = start_line
            self.min_segment_size = min_segment_size

            # Calculate number of segments
            remaining_lines = total_lines - start_line + 1
            self.num_segments = (remaining_lines + segment_size - 1) // segment_size

        def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
            """
            Iterate over line segments.

            Yields:
                LineSegment containing start (inclusive) and end (exclusive) indices
            """
            current = self.start_line

            for i in range(self.num_segments):
                is_last_segment = i == self.num_segments - 1
                segment_end = min(current + self.segment_size, self.total_lines + 1)

                # Handle minimum segment size for last segment
                if (
                    is_last_segment
                    and self.min_segment_size is not None
                    and segment_end - current < self.min_segment_size
                    and i > 0
                ):
                    # Merge with previous segment by not yielding
                    break

                yield NumberedText.LineSegment(current, segment_end)
                current = segment_end

    def __init__(
        self, content: Optional[str] = None, start: int = 1, separator: str = ":"
    ) -> None:
        """
        Initialize a numbered text document, detecting and preserving existing numbering.

        Valid numbered text must have:
        - Sequential line numbers
        - Consistent separator character(s)
        - Every non-empty line must follow the numbering pattern

        Args:
            content: Initial text content, if any
            start: Starting line number (used only if content isn't already numbered)
            separator: Separator between line numbers and content (only if content isn't numbered)

        Examples:
            >>> # Custom separators
            >>> doc = NumberedText("1→First line\\n2→Second line")
            >>> doc.separator == "→"
            True

            >>> # Preserves starting number
            >>> doc = NumberedText("5#First\\n6#Second")
            >>> doc.start == 5
            True

            >>> # Regular numbered list isn't treated as line numbers
            >>> doc = NumberedText("1. First item\\n2. Second item")
            >>> doc.numbered_lines
            ['1: 1. First item', '2: 2. Second item']
        """

        self.lines: List[str] = []  # Declare lines here
        self.start: int = start  # Declare start with its type
        self.separator: str = separator  # and separator

        if not isinstance(content, str):
            raise ValueError("NumberedText requires string input.")

        if start < 1:  # enforce 1 based indexing.
            raise IndexError(
                "NumberedText: Numbered lines must begin on an integer great or equal to 1."
            )

        if not content:
            return

        # Analyze the text format
        is_numbered, detected_sep, start_num = get_numbered_format(content)

        format_info = get_numbered_format(content)

        if format_info.is_numbered:
            self.start = format_info.start_num  # type: ignore
            self.separator = format_info.separator  # type: ignore

            # Extract content by removing number and separator
            pattern = re.compile(rf"^\d+{re.escape(detected_sep)}")
            self.lines = []

            for line in content.splitlines():
                if line.strip():
                    self.lines.append(pattern.sub("", line))
                else:
                    self.lines.append(line)
        else:
            self.lines = content.splitlines()
            self.start = start
            self.separator = separator

    @classmethod
    def from_file(cls, path: Path, **kwargs) -> "NumberedText":
        """Create a NumberedText instance from a file."""
        return cls(Path(path).read_text(), **kwargs)

    def _format_line(self, line_num: int, line: str) -> str:
        return f"{line_num}{self.separator}{line}"

    def _to_internal_index(self, idx: int) -> int:
        """return the index into the lines object in Python 0-based indexing."""
        if idx > 0:
            return idx - self.start
        elif idx < 0:  # allow negative indexing to index from end
            if abs(idx) > self.size:
                raise IndexError(f"NumberedText: negative index out of range: {idx}")
            return self.end + idx  # convert to logical positive location for reference.
        else:
            raise IndexError("NumberedText: Index cannot be zero in 1-based indexing.")

    def __str__(self) -> str:
        """Return the numbered text representation."""
        return "\n".join(
            self._format_line(i, line) for i, line in enumerate(self.lines, self.start)
        )

    def __len__(self) -> int:
        """Return the number of lines."""
        return len(self.lines)

    def __iter__(self) -> Iterator[tuple[int, str]]:
        """Iterate over (line_number, line_content) pairs."""
        return iter((i, line) for i, line in enumerate(self.lines, self.start))

    def __getitem__(self, index: int) -> str:
        """Get line content by line number (1-based indexing)."""
        return self.lines[self._to_internal_index(index)]

    def get_line(self, line_num: int) -> str:
        """Get content of specified line number."""
        return self[line_num]

    def _to_line_index(self, internal_index: int) -> int:
        return self.start + self._to_internal_index(internal_index)

    def get_numbered_line(self, line_num: int) -> str:
        """Get specified line with line number."""
        idx = self._to_line_index(line_num)
        return self._format_line(idx, self[idx])

    def get_lines(self, start: int, end: int) -> List[str]:
        """Get content of line range, not inclusive of end line."""
        return self.lines[self._to_internal_index(start) : self._to_internal_index(end)]

    def get_numbered_lines(self, start: int, end: int) -> List[str]:
        return [
            self._format_line(i + self._to_internal_index(start) + 1, line)
            for i, line in enumerate(self.get_lines(start, end))
        ]

    def get_segment(self, start: int, end: int) -> str:
        if start < self.start:
            raise IndexError(f"Start index {start} is before first line {self.start}")
        if end > len(self) + 1:
            raise IndexError(f"End index {end} is past last line {len(self)}")
        if start >= end:
            raise IndexError(f"Start index {start} must be less than end index {end}")
        return "\n".join(self.get_lines(start, end))

    def iter_segments(
        self, segment_size: int, min_segment_size: Optional[int] = None
    ) -> Iterator[LineSegment]:
        """
        Iterate over segments of the text with specified size.

        Args:
            segment_size: Number of lines per segment
            min_segment_size: Optional minimum size for final segment.
                If specified, last segment will be merged with previous one
                if it would be smaller than this size.

        Yields:
            LineSegment objects containing start and end line numbers

        Example:
            >>> text = NumberedText("line1\\nline2\\nline3\\nline4\\nline5")
            >>> for segment in text.iter_segments(2):
            ...     print(f"Lines {segment.start}-{segment.end}")
            Lines 1-3
            Lines 3-5
            Lines 5-6
        """
        iterator = self.SegmentIterator(
            len(self), segment_size, self.start, min_segment_size
        )
        return iter(iterator)

    def get_numbered_segment(self, start: int, end: int) -> str:
        return "\n".join(self.get_numbered_lines(start, end))

    def save(self, path: Path, numbered: bool = True) -> None:
        """
        Save document to file.

        Args:
            path: Output file path
            numbered: Whether to save with line numbers (default: True)
        """
        content = str(self) if numbered else "\n".join(self.lines)
        Path(path).write_text(content)

    def append(self, text: str) -> None:
        """Append text, splitting into lines if needed."""
        self.lines.extend(text.splitlines())

    def insert(self, line_num: int, text: str) -> None:
        """Insert text at specified line number. Assumes text is not empty."""
        new_lines = text.splitlines()
        internal_idx = self._to_internal_index(line_num)
        self.lines[internal_idx:internal_idx] = new_lines

    @property
    def content(self) -> str:
        """Get original text without line numbers."""
        return "\n".join(self.lines)

    @property
    def size(self) -> int:
        """Get the number of lines."""
        return len(self.lines)

    @property
    def numbered_lines(self) -> List[str]:
        """
        Get list of lines with line numbers included.

        Returns:
            List[str]: Lines with numbers and separator prefixed

        Examples:
            >>> doc = NumberedText("First line\\nSecond line")
            >>> doc.numbered_lines
            ['1: First line', '2: Second line']

        Note:
            - Unlike str(self), this returns a list rather than joined string
            - Maintains consistent formatting with separator
            - Useful for processing or displaying individual numbered lines
        """
        return [
            f"{i}{self.separator}{line}"
            for i, line in enumerate(self.lines, self.start)
        ]

    @property
    def end(self) -> int:
        return self.start + len(self.lines) - 1
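A brief usage sketch combining construction, line access, and segment iteration (the text content is illustrative; NumberedText is exported from tnh_scholar.text_processing per the module's __all__):

from tnh_scholar.text_processing import NumberedText

doc = NumberedText("First line\nSecond line\nThird line\nFourth line\nFifth line")

print(doc.get_line(2))           # Second line
print(doc.get_numbered_line(2))  # 2:Second line (default ':' separator)

# Process the text in two-line chunks; each segment's end index is exclusive.
for segment in doc.iter_segments(2):
    chunk = doc.get_segment(segment.start, segment.end)
    print(f"Lines {segment.start}-{segment.end - 1}:")
    print(chunk)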
content property

Get original text without line numbers.

end property
lines = [] instance-attribute
numbered_lines property

Get list of lines with line numbers included.

Returns:
    List[str]: Lines with numbers and separator prefixed

Examples:

>>> doc = NumberedText("First line\nSecond line")
>>> doc.numbered_lines
['1: First line', '2: Second line']
Note
  • Unlike str(self), this returns a list rather than joined string
  • Maintains consistent formatting with separator
  • Useful for processing or displaying individual numbered lines
separator = separator instance-attribute
size property

Get the number of lines.

start = start instance-attribute
LineSegment dataclass

Represents a segment of lines with start and end indices in 1-based indexing.

The segment follows Python range conventions where start is inclusive and end is exclusive. However, indexing is 1-based to match NumberedText.

Attributes:
    start (int): Starting line number (inclusive, 1-based)
    end (int): Ending line number (exclusive, 1-based)

Source code in src/tnh_scholar/text_processing/numbered_text.py
@dataclass
class LineSegment:
    """
    Represents a segment of lines with start and end indices in 1-based indexing.

    The segment follows Python range conventions where start is inclusive and
    end is exclusive. However, indexing is 1-based to match NumberedText.

    Attributes:
        start: Starting line number (inclusive, 1-based)
        end: Ending line number (exclusive, 1-based)
    """

    start: int
    end: int

    def __iter__(self):
        """Allow unpacking into start, end pairs."""
        yield self.start
        yield self.end
end instance-attribute
start instance-attribute
__init__(start, end)
__iter__()

Allow unpacking into start, end pairs.

Source code in src/tnh_scholar/text_processing/numbered_text.py
57
58
59
60
def __iter__(self):
    """Allow unpacking into start, end pairs."""
    yield self.start
    yield self.end
SegmentIterator

Iterator for generating line segments of specified size.

Produces segments of lines with start/end indices following 1-based indexing. The final segment may be smaller than the specified segment size.

Attributes:
    total_lines: Total number of lines in text
    segment_size: Number of lines per segment
    start_line: Starting line number (1-based)
    min_segment_size: Minimum size for the final segment

Source code in src/tnh_scholar/text_processing/numbered_text.py
class SegmentIterator:
    """
    Iterator for generating line segments of specified size.

    Produces segments of lines with start/end indices following 1-based indexing.
    The final segment may be smaller than the specified segment size.

    Attributes:
        total_lines: Total number of lines in text
        segment_size: Number of lines per segment
        start_line: Starting line number (1-based)
        min_segment_size: Minimum size for the final segment
    """

    def __init__(
        self,
        total_lines: int,
        segment_size: int,
        start_line: int = 1,
        min_segment_size: Optional[int] = None,
    ):
        """
        Initialize the segment iterator.

        Args:
            total_lines: Total number of lines to iterate over
            segment_size: Desired size of each segment
            start_line: First line number (default: 1)
            min_segment_size: Minimum size for final segment (default: None)
                If specified, the last segment will be merged with the previous one
                if it would be smaller than this size.

        Raises:
            ValueError: If segment_size < 1 or total_lines < 1
            ValueError: If start_line < 1 (must use 1-based indexing)
            ValueError: If min_segment_size >= segment_size
        """
        if segment_size < 1:
            raise ValueError("Segment size must be at least 1")
        if total_lines < 1:
            raise ValueError("Total lines must be at least 1")
        if start_line < 1:
            raise ValueError("Start line must be at least 1 (1-based indexing)")
        if min_segment_size is not None and min_segment_size >= segment_size:
            raise ValueError("Minimum segment size must be less than segment size")

        self.total_lines = total_lines
        self.segment_size = segment_size
        self.start_line = start_line
        self.min_segment_size = min_segment_size

        # Calculate number of segments
        remaining_lines = total_lines - start_line + 1
        self.num_segments = (remaining_lines + segment_size - 1) // segment_size

    def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
        """
        Iterate over line segments.

        Yields:
            LineSegment containing start (inclusive) and end (exclusive) indices
        """
        current = self.start_line

        for i in range(self.num_segments):
            is_last_segment = i == self.num_segments - 1
            segment_end = min(current + self.segment_size, self.total_lines + 1)

            # Handle minimum segment size for last segment
            if (
                is_last_segment
                and self.min_segment_size is not None
                and segment_end - current < self.min_segment_size
                and i > 0
            ):
                # Merge with previous segment by not yielding
                break

            yield NumberedText.LineSegment(current, segment_end)
            current = segment_end
min_segment_size = min_segment_size instance-attribute
num_segments = (remaining_lines + segment_size - 1) // segment_size instance-attribute
segment_size = segment_size instance-attribute
start_line = start_line instance-attribute
total_lines = total_lines instance-attribute
__init__(total_lines, segment_size, start_line=1, min_segment_size=None)

Initialize the segment iterator.

Parameters:
    total_lines (int): Total number of lines to iterate over (required)
    segment_size (int): Desired size of each segment (required)
    start_line (int): First line number (default: 1)
    min_segment_size (Optional[int]): Minimum size for final segment (default: None). If specified, the last segment will be merged with the previous one if it would be smaller than this size.

Raises:
    ValueError: If segment_size < 1 or total_lines < 1
    ValueError: If start_line < 1 (must use 1-based indexing)
    ValueError: If min_segment_size >= segment_size

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __init__(
    self,
    total_lines: int,
    segment_size: int,
    start_line: int = 1,
    min_segment_size: Optional[int] = None,
):
    """
    Initialize the segment iterator.

    Args:
        total_lines: Total number of lines to iterate over
        segment_size: Desired size of each segment
        start_line: First line number (default: 1)
        min_segment_size: Minimum size for final segment (default: None)
            If specified, the last segment will be merged with the previous one
            if it would be smaller than this size.

    Raises:
        ValueError: If segment_size < 1 or total_lines < 1
        ValueError: If start_line < 1 (must use 1-based indexing)
        ValueError: If min_segment_size >= segment_size
    """
    if segment_size < 1:
        raise ValueError("Segment size must be at least 1")
    if total_lines < 1:
        raise ValueError("Total lines must be at least 1")
    if start_line < 1:
        raise ValueError("Start line must be at least 1 (1-based indexing)")
    if min_segment_size is not None and min_segment_size >= segment_size:
        raise ValueError("Minimum segment size must be less than segment size")

    self.total_lines = total_lines
    self.segment_size = segment_size
    self.start_line = start_line
    self.min_segment_size = min_segment_size

    # Calculate number of segments
    remaining_lines = total_lines - start_line + 1
    self.num_segments = (remaining_lines + segment_size - 1) // segment_size
__iter__()

Iterate over line segments.

Yields:
    LineSegment: LineSegment containing start (inclusive) and end (exclusive) indices

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
    """
    Iterate over line segments.

    Yields:
        LineSegment containing start (inclusive) and end (exclusive) indices
    """
    current = self.start_line

    for i in range(self.num_segments):
        is_last_segment = i == self.num_segments - 1
        segment_end = min(current + self.segment_size, self.total_lines + 1)

        # Handle minimum segment size for last segment
        if (
            is_last_segment
            and self.min_segment_size is not None
            and segment_end - current < self.min_segment_size
            and i > 0
        ):
            # Merge with previous segment by not yielding
            break

        yield NumberedText.LineSegment(current, segment_end)
        current = segment_end
__getitem__(index)

Get line content by line number (1-based indexing).

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __getitem__(self, index: int) -> str:
    """Get line content by line number (1-based indexing)."""
    return self.lines[self._to_internal_index(index)]
__init__(content=None, start=1, separator=':')

Initialize a numbered text document, detecting and preserving existing numbering.

Valid numbered text must have:
- Sequential line numbers
- Consistent separator character(s)
- Every non-empty line must follow the numbering pattern

Parameters:
    content (Optional[str]): Initial text content, if any (default: None)
    start (int): Starting line number (used only if content isn't already numbered) (default: 1)
    separator (str): Separator between line numbers and content (only if content isn't numbered) (default: ':')

Examples:

>>> # Custom separators
>>> doc = NumberedText("1→First line\n2→Second line")
>>> doc.separator == "→"
True
>>> # Preserves starting number
>>> doc = NumberedText("5#First\n6#Second")
>>> doc.start == 5
True
>>> # Regular numbered list isn't treated as line numbers
>>> doc = NumberedText("1. First item\n2. Second item")
>>> doc.numbered_lines
['1: 1. First item', '2: 2. Second item']
Source code in src/tnh_scholar/text_processing/numbered_text.py
def __init__(
    self, content: Optional[str] = None, start: int = 1, separator: str = ":"
) -> None:
    """
    Initialize a numbered text document, detecting and preserving existing numbering.

    Valid numbered text must have:
    - Sequential line numbers
    - Consistent separator character(s)
    - Every non-empty line must follow the numbering pattern

    Args:
        content: Initial text content, if any
        start: Starting line number (used only if content isn't already numbered)
        separator: Separator between line numbers and content (only if content isn't numbered)

    Examples:
        >>> # Custom separators
        >>> doc = NumberedText("1→First line\\n2→Second line")
        >>> doc.separator == "→"
        True

        >>> # Preserves starting number
        >>> doc = NumberedText("5#First\\n6#Second")
        >>> doc.start == 5
        True

        >>> # Regular numbered list isn't treated as line numbers
        >>> doc = NumberedText("1. First item\\n2. Second item")
        >>> doc.numbered_lines
        ['1: 1. First item', '2: 2. Second item']
    """

    self.lines: List[str] = []  # Declare lines here
    self.start: int = start  # Declare start with its type
    self.separator: str = separator  # and separator

    if not isinstance(content, str):
        raise ValueError("NumberedText requires string input.")

    if start < 1:  # enforce 1 based indexing.
        raise IndexError(
            "NumberedText: Numbered lines must begin on an integer great or equal to 1."
        )

    if not content:
        return

    # Analyze the text format
    is_numbered, detected_sep, start_num = get_numbered_format(content)

    format_info = get_numbered_format(content)

    if format_info.is_numbered:
        self.start = format_info.start_num  # type: ignore
        self.separator = format_info.separator  # type: ignore

        # Extract content by removing number and separator
        pattern = re.compile(rf"^\d+{re.escape(detected_sep)}")
        self.lines = []

        for line in content.splitlines():
            if line.strip():
                self.lines.append(pattern.sub("", line))
            else:
                self.lines.append(line)
    else:
        self.lines = content.splitlines()
        self.start = start
        self.separator = separator
__iter__()

Iterate over (line_number, line_content) pairs.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __iter__(self) -> Iterator[tuple[int, str]]:
    """Iterate over (line_number, line_content) pairs."""
    return iter((i, line) for i, line in enumerate(self.lines, self.start))
__len__()

Return the number of lines.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __len__(self) -> int:
    """Return the number of lines."""
    return len(self.lines)
__str__()

Return the numbered text representation.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def __str__(self) -> str:
    """Return the numbered text representation."""
    return "\n".join(
        self._format_line(i, line) for i, line in enumerate(self.lines, self.start)
    )
append(text)

Append text, splitting into lines if needed.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def append(self, text: str) -> None:
    """Append text, splitting into lines if needed."""
    self.lines.extend(text.splitlines())
from_file(path, **kwargs) classmethod

Create a NumberedText instance from a file.

Source code in src/tnh_scholar/text_processing/numbered_text.py
@classmethod
def from_file(cls, path: Path, **kwargs) -> "NumberedText":
    """Create a NumberedText instance from a file."""
    return cls(Path(path).read_text(), **kwargs)
get_line(line_num)

Get content of specified line number.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_line(self, line_num: int) -> str:
    """Get content of specified line number."""
    return self[line_num]
get_lines(start, end)

Get content of line range, not inclusive of end line.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_lines(self, start: int, end: int) -> List[str]:
    """Get content of line range, not inclusive of end line."""
    return self.lines[self._to_internal_index(start) : self._to_internal_index(end)]
get_numbered_line(line_num)

Get specified line with line number.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_numbered_line(self, line_num: int) -> str:
    """Get specified line with line number."""
    idx = self._to_line_index(line_num)
    return self._format_line(idx, self[idx])
get_numbered_lines(start, end)
Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_numbered_lines(self, start: int, end: int) -> List[str]:
    return [
        self._format_line(i + self._to_internal_index(start) + 1, line)
        for i, line in enumerate(self.get_lines(start, end))
    ]
get_numbered_segment(start, end)
Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_numbered_segment(self, start: int, end: int) -> str:
    return "\n".join(self.get_numbered_lines(start, end))
get_segment(start, end)
Source code in src/tnh_scholar/text_processing/numbered_text.py
def get_segment(self, start: int, end: int) -> str:
    if start < self.start:
        raise IndexError(f"Start index {start} is before first line {self.start}")
    if end > len(self) + 1:
        raise IndexError(f"End index {end} is past last line {len(self)}")
    if start >= end:
        raise IndexError(f"Start index {start} must be less than end index {end}")
    return "\n".join(self.get_lines(start, end))
insert(line_num, text)

Insert text at specified line number. Assumes text is not empty.

Source code in src/tnh_scholar/text_processing/numbered_text.py
def insert(self, line_num: int, text: str) -> None:
    """Insert text at specified line number. Assumes text is not empty."""
    new_lines = text.splitlines()
    internal_idx = self._to_internal_index(line_num)
    self.lines[internal_idx:internal_idx] = new_lines
iter_segments(segment_size, min_segment_size=None)

Iterate over segments of the text with specified size.

Parameters:
    segment_size (int): Number of lines per segment (required)
    min_segment_size (Optional[int]): Optional minimum size for final segment. If specified, last segment will be merged with previous one if it would be smaller than this size. (default: None)

Yields:
    LineSegment: LineSegment objects containing start and end line numbers

Example:
    >>> text = NumberedText("line1\nline2\nline3\nline4\nline5")
    >>> for segment in text.iter_segments(2):
    ...     print(f"Lines {segment.start}-{segment.end}")
    Lines 1-3
    Lines 3-5
    Lines 5-6

Source code in src/tnh_scholar/text_processing/numbered_text.py
def iter_segments(
    self, segment_size: int, min_segment_size: Optional[int] = None
) -> Iterator[LineSegment]:
    """
    Iterate over segments of the text with specified size.

    Args:
        segment_size: Number of lines per segment
        min_segment_size: Optional minimum size for final segment.
            If specified, last segment will be merged with previous one
            if it would be smaller than this size.

    Yields:
        LineSegment objects containing start and end line numbers

    Example:
        >>> text = NumberedText("line1\\nline2\\nline3\\nline4\\nline5")
        >>> for segment in text.iter_segments(2):
        ...     print(f"Lines {segment.start}-{segment.end}")
        Lines 1-3
        Lines 3-5
        Lines 5-6
    """
    iterator = self.SegmentIterator(
        len(self), segment_size, self.start, min_segment_size
    )
    return iter(iterator)
save(path, numbered=True)

Save document to file.

Parameters:
    path (Path): Output file path (required)
    numbered (bool): Whether to save with line numbers (default: True)
Source code in src/tnh_scholar/text_processing/numbered_text.py
def save(self, path: Path, numbered: bool = True) -> None:
    """
    Save document to file.

    Args:
        path: Output file path
        numbered: Whether to save with line numbers (default: True)
    """
    content = str(self) if numbered else "\n".join(self.lines)
    Path(path).write_text(content)

bracket_lines(text, number=False)

Encloses each line of the input text with angle brackets.
If number is True, adds a line number followed by a colon `:` and then the line.

Args:
    text (str): The input string containing lines separated by '\n'.
    number (bool): Whether to prepend line numbers to each line.

Returns:
    str: A string where each line is enclosed in angle brackets.

Examples:
    >>> bracket_lines("This is a string with\n   two lines.")
    '<This is a string with>\n<   two lines.>'

    >>> bracket_lines("This is a string with\n   two lines.", number=True)
    '<1:This is a string with>\n<2:   two lines.>'

Source code in src/tnh_scholar/text_processing/bracket.py
def bracket_lines(text: str, number: bool = False) -> str:
    """
    Encloses each line of the input text with angle brackets.
    If number is True, adds a line number followed by a colon `:` and then the line.

    Args:
        text (str): The input string containing lines separated by '\n'.
        number (bool): Whether to prepend line numbers to each line.

    Returns:
        str: A string where each line is enclosed in angle brackets.

    Examples:
        >>> bracket_lines("This is a string with\n   two lines.")
        '<This is a string with>\n<   two lines.>'

        >>> bracket_lines("This is a string with\n   two lines.", number=True)
        '<1:This is a string with>\n<2:   two lines.>'
    """
    return "\n".join(
        f"<{f'{i+1}:{line}' if number else line}>"
        for i, line in enumerate(text.split("\n"))
    )
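A quick usage sketch for bracket_lines (exported from tnh_scholar.text_processing; the sample text is illustrative):

from tnh_scholar.text_processing import bracket_lines

text = "Breathing in, I calm my body.\nBreathing out, I smile."

print(bracket_lines(text))
# <Breathing in, I calm my body.>
# <Breathing out, I smile.>

print(bracket_lines(text, number=True))
# <1:Breathing in, I calm my body.>
# <2:Breathing out, I smile.>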

clean_text(text, newline=False)

Cleans a given text by replacing specific unwanted characters such as tab, and non-breaking spaces with regular spaces.

This function takes a string as input and applies replacements based on a predefined mapping of characters to replace.

Parameters:
    text (str): The text to be cleaned. (required)
    newline (bool): If True, newlines are also removed. Default: False

Returns:
    str: The cleaned text with unwanted characters replaced by spaces.

Example:
    >>> text = "This is\n an example\ttext with\xa0extra spaces."
    >>> clean_text(text)
    'This is an example text with extra spaces.'

Source code in src/tnh_scholar/text_processing/text_processing.py
def clean_text(text, newline=False):
    """
    Cleans a given text by replacing specific unwanted characters such as
    tab, and non-breaking spaces with regular spaces.

    This function takes a string as input and applies replacements
    based on a predefined mapping of characters to replace.

    Args:
        text (str): The text to be cleaned.

    Returns:
        str: The cleaned text with unwanted characters replaced by spaces.

    Example:
        >>> text = "This is\\n an example\\ttext with\\xa0extra spaces."
        >>> clean_text(text)
        'This is an example text with extra spaces.'

    """
    # Define a mapping of characters to replace
    replace_map = {
        "\t": " ",  # Replace tabs with space
        "\xa0": " ",  # Replace non-breaking space with regular space
        # Add more replacements as needed
    }

    if newline:
        replace_map["\n"] = ""  # remove newlines

    # Loop through the replace map and replace each character
    for old_char, new_char in replace_map.items():
        text = text.replace(old_char, new_char)

    return text.strip()  # Ensure any leading/trailing spaces are removed
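A quick usage sketch for clean_text showing the effect of the newline flag (exported from tnh_scholar.text_processing; the sample text is illustrative):

from tnh_scholar.text_processing import clean_text

raw = "Line one\nLine two\twith\xa0spaces  "

print(repr(clean_text(raw)))                # 'Line one\nLine two with spaces' (tabs and non-breaking spaces replaced, ends stripped)
print(repr(clean_text(raw, newline=True)))  # 'Line oneLine two with spaces' (newlines removed outright)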

lines_from_bracketed_text(text, start, end, keep_brackets=False)

Extracts lines from bracketed text between the start and end indices, inclusive.
Handles both numbered and non-numbered cases.

Args:
    text (str): The input bracketed text containing lines like <...>.
    start (int): The starting line number (1-based).
    end (int): The ending line number (1-based).

Returns:
    list[str]: The lines from start to end inclusive, with angle brackets removed.

Raises:
    FormattingError: If the text contains improperly formatted lines (missing angle brackets).
    ValueError: If start or end indices are invalid or out of bounds.

Examples:
    >>> text = "<1:Line 1>\n<2:Line 2>\n<3:Line 3>"
    >>> lines_from_bracketed_text(text, 1, 2)
    ['Line 1', 'Line 2']

    >>> text = "<Line 1>\n<Line 2>\n<Line 3>"
    >>> lines_from_bracketed_text(text, 2, 3)
    ['Line 2', 'Line 3']

Source code in src/tnh_scholar/text_processing/bracket.py, lines 131-182
def lines_from_bracketed_text(
    text: str, start: int, end: int, keep_brackets=False
) -> list[str]:
    """
    Extracts lines from bracketed text between the start and end indices, inclusive.
    Handles both numbered and non-numbered cases.

    Args:
        text (str): The input bracketed text containing lines like <...>.
        start (int): The starting line number (1-based).
        end (int): The ending line number (1-based).

    Returns:
        list[str]: The lines from start to end inclusive, with angle brackets removed.

    Raises:
        FormattingError: If the text contains improperly formatted lines (missing angle brackets).
        ValueError: If start or end indices are invalid or out of bounds.

    Examples:
        >>> text = "<1:Line 1>\n<2:Line 2>\n<3:Line 3>"
        >>> lines_from_bracketed_text(text, 1, 2)
        ['Line 1', 'Line 2']

        >>> text = "<Line 1>\n<Line 2>\n<Line 3>"
        >>> lines_from_bracketed_text(text, 2, 3)
        ['Line 2', 'Line 3']
    """
    # Split the text into lines
    lines = text.splitlines()

    # Validate indices
    if start < 1 or end < 1 or start > end or end > len(lines):
        raise ValueError(
            f"Invalid start or end indices for the given text: start:{start}, end: {end}"
        )

    # Extract lines and validate formatting
    result = []
    for i, line in enumerate(lines, start=1):
        if start <= i <= end:
            # Check for proper bracketing and extract the content
            match = re.match(r"<(\d+:)?(.*?)>", line)
            if not match:
                raise FormattingError(f"Invalid format for line {i}: '{line}'")
            # Add the extracted content (group 2) to the result
            if keep_brackets:
                result.append(line)
            else:
                result.append(match[2].strip())

    return "\n".join(result)

normalize_newlines(text, spacing=2)

Normalize newline blocks in the input text by reducing consecutive newlines
to the specified number of newlines for consistent readability and formatting.

Parameters:
----------
text : str
    The input text containing inconsistent newline spacing.
spacing : int, optional
    The number of newlines to insert between lines. Defaults to 2.

Returns:
-------
str
    The text with consecutive newlines reduced to the specified number of newlines.

Example:
--------
>>> raw_text = "Heading

Paragraph text 1 Paragraph text 2

" >>> normalize_newlines(raw_text, spacing=2) 'Heading

Paragraph text 1

Paragraph text 2

'

Source code in src/tnh_scholar/text_processing/text_processing.py, lines 3-28
def normalize_newlines(text: str, spacing: int = 2) -> str:
    """
    Normalize newline blocks in the input text by reducing consecutive newlines
    to the specified number of newlines for consistent readability and formatting.

    Parameters:
    ----------
    text : str
        The input text containing inconsistent newline spacing.
    spacing : int, optional
        The number of newlines to insert between lines. Defaults to 2.

    Returns:
    -------
    str
        The text with consecutive newlines reduced to the specified number of newlines.

    Example:
    --------
    >>> raw_text = "Heading\n\n\nParagraph text 1\nParagraph text 2\n\n\n"
    >>> normalize_newlines(raw_text, spacing=2)
    'Heading\n\nParagraph text 1\n\nParagraph text 2\n\n'
    """
    # Replace one or more newlines with the desired number of newlines
    newlines = "\n" * spacing
    return re.sub(r"\n{1,}", newlines, text)
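
For illustration, a minimal sketch with spacing=1, which collapses every run of newlines to a single newline (hypothetical input).

>>> normalize_newlines("Heading\n\n\nBody line 1\n\nBody line 2", spacing=1)
'Heading\nBody line 1\nBody line 2'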

unbracket_lines(text, number=False)

Removes angle brackets (< >) from encapsulated lines and optionally removes line numbers.

Args:
    text (str): The input string with encapsulated lines.
    number (bool): If True, removes line numbers in the format 'digit:'.
                   Raises a FormattingError if `number=True` and a line does not start with a digit followed by a colon.

Returns:
    str: A newline-separated string with the encapsulation removed, and line numbers stripped if specified.

Examples:
    >>> unbracket_lines("<1:Line 1>

<2:Line 2>", number=True) 'Line 1 Line 2'

    >>> unbracket_lines("<Line 1>

") 'Line 1 Line 2'

    >>> unbracket_lines("<1Line 1>", number=True)
    ValueError: Line does not start with a valid number: '1Line 1'
Source code in src/tnh_scholar/text_processing/bracket.py, lines 82-118
def unbracket_lines(text: str, number: bool = False) -> str:
    """
    Removes angle brackets (< >) from encapsulated lines and optionally removes line numbers.

    Args:
        text (str): The input string with encapsulated lines.
        number (bool): If True, removes line numbers in the format 'digit:'.
                       Raises a FormattingError if `number=True` and a line does not start with a digit followed by a colon.

    Returns:
        str: A newline-separated string with the encapsulation removed, and line numbers stripped if specified.

    Examples:
        >>> unbracket_lines("<1:Line 1>\n<2:Line 2>", number=True)
        'Line 1\nLine 2'

        >>> unbracket_lines("<Line 1>\n<Line 2>")
        'Line 1\nLine 2'

        >>> unbracket_lines("<1Line 1>", number=True)
        ValueError: Line does not start with a valid number: '1Line 1'
    """
    unbracketed_lines = []

    for line in text.splitlines():
        match = (
            re.match(r"<(\d+):(.*?)>", line) if number else re.match(r"<(.*?)>", line)
        )
        if match:
            content = match[2].strip() if number else match[1].strip()
            unbracketed_lines.append(content)
        elif number:
            raise FormattingError(f"Line does not start with a valid number: '{line}'")
        else:
            raise FormattingError(f"Line does not follow the expected format: '{line}'")

    return "\n".join(unbracketed_lines)

bracket

FormattingError

Bases: Exception

Custom exception raised for formatting-related errors.

Source code in src/tnh_scholar/text_processing/bracket.py, lines 5-11
class FormattingError(Exception):
    """
    Custom exception raised for formatting-related errors.
    """

    def __init__(self, message="An error occurred due to invalid formatting."):
        super().__init__(message)
__init__(message='An error occurred due to invalid formatting.')
Source code in src/tnh_scholar/text_processing/bracket.py, lines 10-11
def __init__(self, message="An error occurred due to invalid formatting."):
    super().__init__(message)
bracket_all_lines(pages)
Source code in src/tnh_scholar/text_processing/bracket.py, lines 78-79
def bracket_all_lines(pages):
    return [bracket_lines(page) for page in pages]
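
For illustration, a minimal sketch applying bracket_lines to each page in a list (hypothetical input).

>>> bracket_all_lines(["a\nb", "c"])
['<a>\n<b>', '<c>']
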
bracket_lines(text, number=False)
Encloses each line of the input text with angle brackets.
If number is True, adds a line number followed by a colon `:` and then the line.

Args:
    text (str): The input string containing lines separated by '\n'.
    number (bool): Whether to prepend line numbers to each line.

Returns:
    str: A string where each line is enclosed in angle brackets.

Examples:
    >>> bracket_lines("This is a string with

two lines.") ' < two lines.>'

    >>> bracket_lines("This is a string with

two lines.", number=True) '<1:This is a string with> <2: two lines.>'

Source code in src/tnh_scholar/text_processing/bracket.py, lines 16-38
def bracket_lines(text: str, number: bool = False) -> str:
    """
    Encloses each line of the input text with angle brackets.
    If number is True, adds a line number followed by a colon `:` and then the line.

    Args:
        text (str): The input string containing lines separated by '\n'.
        number (bool): Whether to prepend line numbers to each line.

    Returns:
        str: A string where each line is enclosed in angle brackets.

    Examples:
        >>> bracket_lines("This is a string with\n   two lines.")
        '<This is a string with>\n<   two lines.>'

        >>> bracket_lines("This is a string with\n   two lines.", number=True)
        '<1:This is a string with>\n<2:   two lines.>'
    """
    return "\n".join(
        f"<{f'{i+1}:{line}' if number else line}>"
        for i, line in enumerate(text.split("\n"))
    )
lines_from_bracketed_text(text, start, end, keep_brackets=False)
Extracts lines from bracketed text between the start and end indices, inclusive.
Handles both numbered and non-numbered cases.

Args:
    text (str): The input bracketed text containing lines like <...>.
    start (int): The starting line number (1-based).
    end (int): The ending line number (1-based).

Returns:
    list[str]: The lines from start to end inclusive, with angle brackets removed.

Raises:
    FormattingError: If the text contains improperly formatted lines (missing angle brackets).
    ValueError: If start or end indices are invalid or out of bounds.

Examples:
    >>> text = "<1:Line 1>

<2:Line 2> <3:Line 3>" >>> lines_from_bracketed_text(text, 1, 2) ['Line 1', 'Line 2']

    >>> text = "<Line 1>

" >>> lines_from_bracketed_text(text, 2, 3) ['Line 2', 'Line 3']

Source code in src/tnh_scholar/text_processing/bracket.py, lines 131-182
def lines_from_bracketed_text(
    text: str, start: int, end: int, keep_brackets=False
) -> list[str]:
    """
    Extracts lines from bracketed text between the start and end indices, inclusive.
    Handles both numbered and non-numbered cases.

    Args:
        text (str): The input bracketed text containing lines like <...>.
        start (int): The starting line number (1-based).
        end (int): The ending line number (1-based).

    Returns:
        list[str]: The lines from start to end inclusive, with angle brackets removed.

    Raises:
        FormattingError: If the text contains improperly formatted lines (missing angle brackets).
        ValueError: If start or end indices are invalid or out of bounds.

    Examples:
        >>> text = "<1:Line 1>\n<2:Line 2>\n<3:Line 3>"
        >>> lines_from_bracketed_text(text, 1, 2)
        ['Line 1', 'Line 2']

        >>> text = "<Line 1>\n<Line 2>\n<Line 3>"
        >>> lines_from_bracketed_text(text, 2, 3)
        ['Line 2', 'Line 3']
    """
    # Split the text into lines
    lines = text.splitlines()

    # Validate indices
    if start < 1 or end < 1 or start > end or end > len(lines):
        raise ValueError(
            f"Invalid start or end indices for the given text: start:{start}, end: {end}"
        )

    # Extract lines and validate formatting
    result = []
    for i, line in enumerate(lines, start=1):
        if start <= i <= end:
            # Check for proper bracketing and extract the content
            match = re.match(r"<(\d+:)?(.*?)>", line)
            if not match:
                raise FormattingError(f"Invalid format for line {i}: '{line}'")
            # Add the extracted content (group 2) to the result
            if keep_brackets:
                result.append(line)
            else:
                result.append(match[2].strip())

    return "\n".join(result)
number_lines(text, start=1, separator=': ')

Numbers each line of text with a readable format, including empty lines.

Parameters:

    text (str): Input text to be numbered. Can be multi-line. (required)
    start (int): Starting line number. Defaults to 1.
    separator (str): Separator between line number and content. Defaults to ": ".

Returns:

    str: Numbered text where each line starts with "{number}: ".

Examples:

>>> text = "First line\nSecond line\n\nFourth line"
>>> print(number_lines(text))
1: First line
2: Second line
3:
4: Fourth line
>>> print(number_lines(text, start=5, separator=" | "))
5 | First line
6 | Second line
7 |
8 | Fourth line
Notes
  • All lines are numbered, including empty lines, to maintain text structure
  • Line numbers are aligned through natural string formatting
  • Customizable separator allows for different formatting needs
  • Can start from any line number for flexibility in text processing
Source code in src/tnh_scholar/text_processing/bracket.py, lines 41-75
def number_lines(text: str, start: int = 1, separator: str = ": ") -> str:
    """
    Numbers each line of text with a readable format, including empty lines.

    Args:
        text (str): Input text to be numbered. Can be multi-line.
        start (int, optional): Starting line number. Defaults to 1.
        separator (str, optional): Separator between line number and content.
            Defaults to ": ".

    Returns:
        str: Numbered text where each line starts with "{number}: ".

    Examples:
        >>> text = "First line\\nSecond line\\n\\nFourth line"
        >>> print(number_lines(text))
        1: First line
        2: Second line
        3:
        4: Fourth line

        >>> print(number_lines(text, start=5, separator=" | "))
        5 | First line
        6 | Second line
        7 |
        8 | Fourth line

    Notes:
        - All lines are numbered, including empty lines, to maintain text structure
        - Line numbers are aligned through natural string formatting
        - Customizable separator allows for different formatting needs
        - Can start from any line number for flexibility in text processing
    """
    lines = text.splitlines()
    return "\n".join(f"{i}{separator}{line}" for i, line in enumerate(lines, start))
unbracket_all_lines(pages)
Source code in src/tnh_scholar/text_processing/bracket.py, lines 121-128
def unbracket_all_lines(pages):
    result = []
    for page in pages:
        if page == "blank page":
            result.append(page)
        else:
            result.append(unbracket_lines(page))
    return result
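
For illustration, a minimal sketch (hypothetical input): pages marked "blank page" are passed through unchanged, all others are unbracketed.

>>> unbracket_all_lines(["<a>\n<b>", "blank page"])
['a\nb', 'blank page']
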
unbracket_lines(text, number=False)
Removes angle brackets (< >) from encapsulated lines and optionally removes line numbers.

Args:
    text (str): The input string with encapsulated lines.
    number (bool): If True, removes line numbers in the format 'digit:'.
                   Raises a FormattingError if `number=True` and a line does not start with a digit followed by a colon.

Returns:
    str: A newline-separated string with the encapsulation removed, and line numbers stripped if specified.

Examples:
    >>> unbracket_lines("<1:Line 1>

<2:Line 2>", number=True) 'Line 1 Line 2'

    >>> unbracket_lines("<Line 1>

") 'Line 1 Line 2'

    >>> unbracket_lines("<1Line 1>", number=True)
    ValueError: Line does not start with a valid number: '1Line 1'
Source code in src/tnh_scholar/text_processing/bracket.py, lines 82-118
def unbracket_lines(text: str, number: bool = False) -> str:
    """
    Removes angle brackets (< >) from encapsulated lines and optionally removes line numbers.

    Args:
        text (str): The input string with encapsulated lines.
        number (bool): If True, removes line numbers in the format 'digit:'.
                       Raises a FormattingError if `number=True` and a line does not start with a digit followed by a colon.

    Returns:
        str: A newline-separated string with the encapsulation removed, and line numbers stripped if specified.

    Examples:
        >>> unbracket_lines("<1:Line 1>\n<2:Line 2>", number=True)
        'Line 1\nLine 2'

        >>> unbracket_lines("<Line 1>\n<Line 2>")
        'Line 1\nLine 2'

        >>> unbracket_lines("<1Line 1>", number=True)
        ValueError: Line does not start with a valid number: '1Line 1'
    """
    unbracketed_lines = []

    for line in text.splitlines():
        match = (
            re.match(r"<(\d+):(.*?)>", line) if number else re.match(r"<(.*?)>", line)
        )
        if match:
            content = match[2].strip() if number else match[1].strip()
            unbracketed_lines.append(content)
        elif number:
            raise FormattingError(f"Line does not start with a valid number: '{line}'")
        else:
            raise FormattingError(f"Line does not follow the expected format: '{line}'")

    return "\n".join(unbracketed_lines)

numbered_text

NumberedFormat

Bases: NamedTuple

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 7-10
class NumberedFormat(NamedTuple):
    is_numbered: bool
    separator: Optional[str] = None
    start_num: Optional[int] = None
is_numbered instance-attribute
separator = None class-attribute instance-attribute
start_num = None class-attribute instance-attribute
NumberedText

Represents a text document with numbered lines for easy reference and manipulation.

Provides utilities for working with line-numbered text including reading, writing, accessing lines by number, and iterating over numbered lines.

Attributes:

    lines (List[str]): List of text lines
    start (int): Starting line number (default: 1)
    separator (str): Separator between line number and content (default: ": ")

Examples:

>>> text = "First line\nSecond line\n\nFourth line"
>>> doc = NumberedText(text)
>>> print(doc)
1: First line
2: Second line
3:
4: Fourth line
>>> print(doc.get_line(2))
Second line
>>> for num, line in doc:
...     print(f"Line {num}: {len(line)} chars")
Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 13-369
class NumberedText:
    """
    Represents a text document with numbered lines for easy reference and manipulation.

    Provides utilities for working with line-numbered text including reading,
    writing, accessing lines by number, and iterating over numbered lines.

    Attributes:
        lines (List[str]): List of text lines
        start (int): Starting line number (default: 1)
        separator (str): Separator between line number and content (default: ": ")

    Examples:
        >>> text = "First line\\nSecond line\\n\\nFourth line"
        >>> doc = NumberedText(text)
        >>> print(doc)
        1: First line
        2: Second line
        3:
        4: Fourth line

        >>> print(doc.get_line(2))
        Second line

        >>> for num, line in doc:
        ...     print(f"Line {num}: {len(line)} chars")
    """

    @dataclass
    class LineSegment:
        """
        Represents a segment of lines with start and end indices in 1-based indexing.

        The segment follows Python range conventions where start is inclusive and
        end is exclusive. However, indexing is 1-based to match NumberedText.

        Attributes:
            start: Starting line number (inclusive, 1-based)
            end: Ending line number (exclusive, 1-based)
        """

        start: int
        end: int

        def __iter__(self):
            """Allow unpacking into start, end pairs."""
            yield self.start
            yield self.end

    class SegmentIterator:
        """
        Iterator for generating line segments of specified size.

        Produces segments of lines with start/end indices following 1-based indexing.
        The final segment may be smaller than the specified segment size.

        Attributes:
            total_lines: Total number of lines in text
            segment_size: Number of lines per segment
            start_line: Starting line number (1-based)
            min_segment_size: Minimum size for the final segment
        """

        def __init__(
            self,
            total_lines: int,
            segment_size: int,
            start_line: int = 1,
            min_segment_size: Optional[int] = None,
        ):
            """
            Initialize the segment iterator.

            Args:
                total_lines: Total number of lines to iterate over
                segment_size: Desired size of each segment
                start_line: First line number (default: 1)
                min_segment_size: Minimum size for final segment (default: None)
                    If specified, the last segment will be merged with the previous one
                    if it would be smaller than this size.

            Raises:
                ValueError: If segment_size < 1 or total_lines < 1
                ValueError: If start_line < 1 (must use 1-based indexing)
                ValueError: If min_segment_size >= segment_size
            """
            if segment_size < 1:
                raise ValueError("Segment size must be at least 1")
            if total_lines < 1:
                raise ValueError("Total lines must be at least 1")
            if start_line < 1:
                raise ValueError("Start line must be at least 1 (1-based indexing)")
            if min_segment_size is not None and min_segment_size >= segment_size:
                raise ValueError("Minimum segment size must be less than segment size")

            self.total_lines = total_lines
            self.segment_size = segment_size
            self.start_line = start_line
            self.min_segment_size = min_segment_size

            # Calculate number of segments
            remaining_lines = total_lines - start_line + 1
            self.num_segments = (remaining_lines + segment_size - 1) // segment_size

        def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
            """
            Iterate over line segments.

            Yields:
                LineSegment containing start (inclusive) and end (exclusive) indices
            """
            current = self.start_line

            for i in range(self.num_segments):
                is_last_segment = i == self.num_segments - 1
                segment_end = min(current + self.segment_size, self.total_lines + 1)

                # Handle minimum segment size for last segment
                if (
                    is_last_segment
                    and self.min_segment_size is not None
                    and segment_end - current < self.min_segment_size
                    and i > 0
                ):
                    # Merge with previous segment by not yielding
                    break

                yield NumberedText.LineSegment(current, segment_end)
                current = segment_end

    def __init__(
        self, content: Optional[str] = None, start: int = 1, separator: str = ":"
    ) -> None:
        """
        Initialize a numbered text document, detecting and preserving existing numbering.

        Valid numbered text must have:
        - Sequential line numbers
        - Consistent separator character(s)
        - Every non-empty line must follow the numbering pattern

        Args:
            content: Initial text content, if any
            start: Starting line number (used only if content isn't already numbered)
            separator: Separator between line numbers and content (only if content isn't numbered)

        Examples:
            >>> # Custom separators
            >>> doc = NumberedText("1→First line\\n2→Second line")
            >>> doc.separator == "→"
            True

            >>> # Preserves starting number
            >>> doc = NumberedText("5#First\\n6#Second")
            >>> doc.start == 5
            True

            >>> # Regular numbered list isn't treated as line numbers
            >>> doc = NumberedText("1. First item\\n2. Second item")
            >>> doc.numbered_lines
            ['1: 1. First item', '2: 2. Second item']
        """

        self.lines: List[str] = []  # Declare lines here
        self.start: int = start  # Declare start with its type
        self.separator: str = separator  # and separator

        if not isinstance(content, str):
            raise ValueError("NumberedText requires string input.")

        if start < 1:  # enforce 1 based indexing.
            raise IndexError(
                "NumberedText: Numbered lines must begin on an integer great or equal to 1."
            )

        if not content:
            return

        # Analyze the text format
        is_numbered, detected_sep, start_num = get_numbered_format(content)

        format_info = get_numbered_format(content)

        if format_info.is_numbered:
            self.start = format_info.start_num  # type: ignore
            self.separator = format_info.separator  # type: ignore

            # Extract content by removing number and separator
            pattern = re.compile(rf"^\d+{re.escape(detected_sep)}")
            self.lines = []

            for line in content.splitlines():
                if line.strip():
                    self.lines.append(pattern.sub("", line))
                else:
                    self.lines.append(line)
        else:
            self.lines = content.splitlines()
            self.start = start
            self.separator = separator

    @classmethod
    def from_file(cls, path: Path, **kwargs) -> "NumberedText":
        """Create a NumberedText instance from a file."""
        return cls(Path(path).read_text(), **kwargs)

    def _format_line(self, line_num: int, line: str) -> str:
        return f"{line_num}{self.separator}{line}"

    def _to_internal_index(self, idx: int) -> int:
        """return the index into the lines object in Python 0-based indexing."""
        if idx > 0:
            return idx - self.start
        elif idx < 0:  # allow negative indexing to index from end
            if abs(idx) > self.size:
                raise IndexError(f"NumberedText: negative index out of range: {idx}")
            return self.end + idx  # convert to logical positive location for reference.
        else:
            raise IndexError("NumberedText: Index cannot be zero in 1-based indexing.")

    def __str__(self) -> str:
        """Return the numbered text representation."""
        return "\n".join(
            self._format_line(i, line) for i, line in enumerate(self.lines, self.start)
        )

    def __len__(self) -> int:
        """Return the number of lines."""
        return len(self.lines)

    def __iter__(self) -> Iterator[tuple[int, str]]:
        """Iterate over (line_number, line_content) pairs."""
        return iter((i, line) for i, line in enumerate(self.lines, self.start))

    def __getitem__(self, index: int) -> str:
        """Get line content by line number (1-based indexing)."""
        return self.lines[self._to_internal_index(index)]

    def get_line(self, line_num: int) -> str:
        """Get content of specified line number."""
        return self[line_num]

    def _to_line_index(self, internal_index: int) -> int:
        return self.start + self._to_internal_index(internal_index)

    def get_numbered_line(self, line_num: int) -> str:
        """Get specified line with line number."""
        idx = self._to_line_index(line_num)
        return self._format_line(idx, self[idx])

    def get_lines(self, start: int, end: int) -> List[str]:
        """Get content of line range, not inclusive of end line."""
        return self.lines[self._to_internal_index(start) : self._to_internal_index(end)]

    def get_numbered_lines(self, start: int, end: int) -> List[str]:
        return [
            self._format_line(i + self._to_internal_index(start) + 1, line)
            for i, line in enumerate(self.get_lines(start, end))
        ]

    def get_segment(self, start: int, end: int) -> str:
        if start < self.start:
            raise IndexError(f"Start index {start} is before first line {self.start}")
        if end > len(self) + 1:
            raise IndexError(f"End index {end} is past last line {len(self)}")
        if start >= end:
            raise IndexError(f"Start index {start} must be less than end index {end}")
        return "\n".join(self.get_lines(start, end))

    def iter_segments(
        self, segment_size: int, min_segment_size: Optional[int] = None
    ) -> Iterator[LineSegment]:
        """
        Iterate over segments of the text with specified size.

        Args:
            segment_size: Number of lines per segment
            min_segment_size: Optional minimum size for final segment.
                If specified, last segment will be merged with previous one
                if it would be smaller than this size.

        Yields:
            LineSegment objects containing start and end line numbers

        Example:
            >>> text = NumberedText("line1\\nline2\\nline3\\nline4\\nline5")
            >>> for segment in text.iter_segments(2):
            ...     print(f"Lines {segment.start}-{segment.end}")
            Lines 1-3
            Lines 3-5
            Lines 5-6
        """
        iterator = self.SegmentIterator(
            len(self), segment_size, self.start, min_segment_size
        )
        return iter(iterator)

    def get_numbered_segment(self, start: int, end: int) -> str:
        return "\n".join(self.get_numbered_lines(start, end))

    def save(self, path: Path, numbered: bool = True) -> None:
        """
        Save document to file.

        Args:
            path: Output file path
            numbered: Whether to save with line numbers (default: True)
        """
        content = str(self) if numbered else "\n".join(self.lines)
        Path(path).write_text(content)

    def append(self, text: str) -> None:
        """Append text, splitting into lines if needed."""
        self.lines.extend(text.splitlines())

    def insert(self, line_num: int, text: str) -> None:
        """Insert text at specified line number. Assumes text is not empty."""
        new_lines = text.splitlines()
        internal_idx = self._to_internal_index(line_num)
        self.lines[internal_idx:internal_idx] = new_lines

    @property
    def content(self) -> str:
        """Get original text without line numbers."""
        return "\n".join(self.lines)

    @property
    def size(self) -> int:
        """Get the number of lines."""
        return len(self.lines)

    @property
    def numbered_lines(self) -> List[str]:
        """
        Get list of lines with line numbers included.

        Returns:
            List[str]: Lines with numbers and separator prefixed

        Examples:
            >>> doc = NumberedText("First line\\nSecond line")
            >>> doc.numbered_lines
            ['1: First line', '2: Second line']

        Note:
            - Unlike str(self), this returns a list rather than joined string
            - Maintains consistent formatting with separator
            - Useful for processing or displaying individual numbered lines
        """
        return [
            f"{i}{self.separator}{line}"
            for i, line in enumerate(self.lines, self.start)
        ]

    @property
    def end(self) -> int:
        return self.start + len(self.lines) - 1
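
For illustration, a minimal usage sketch based on the class as listed above (hypothetical input). The separator is passed explicitly to match the ": " spacing shown in the class docstring, and the example assumes plain, un-numbered input is detected as such by get_numbered_format.

>>> doc = NumberedText("First\nSecond\nThird", separator=": ")
>>> doc.numbered_lines
['1: First', '2: Second', '3: Third']
>>> doc.get_segment(1, 3)   # end index is exclusive
'First\nSecond'
>>> doc[-1]                 # negative indices count from the end
'Third'
>>> doc.size, doc.end
(3, 3)
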
content property

Get original text without line numbers.

end property
lines = [] instance-attribute
numbered_lines property

Get list of lines with line numbers included.

Returns:

    List[str]: Lines with numbers and separator prefixed

Examples:

>>> doc = NumberedText("First line\nSecond line")
>>> doc.numbered_lines
['1: First line', '2: Second line']
Note
  • Unlike str(self), this returns a list rather than joined string
  • Maintains consistent formatting with separator
  • Useful for processing or displaying individual numbered lines
separator = separator instance-attribute
size property

Get the number of lines.

start = start instance-attribute
LineSegment dataclass

Represents a segment of lines with start and end indices in 1-based indexing.

The segment follows Python range conventions where start is inclusive and end is exclusive. However, indexing is 1-based to match NumberedText.

Attributes:

    start (int): Starting line number (inclusive, 1-based)
    end (int): Ending line number (exclusive, 1-based)

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 41-60
@dataclass
class LineSegment:
    """
    Represents a segment of lines with start and end indices in 1-based indexing.

    The segment follows Python range conventions where start is inclusive and
    end is exclusive. However, indexing is 1-based to match NumberedText.

    Attributes:
        start: Starting line number (inclusive, 1-based)
        end: Ending line number (exclusive, 1-based)
    """

    start: int
    end: int

    def __iter__(self):
        """Allow unpacking into start, end pairs."""
        yield self.start
        yield self.end
end instance-attribute
start instance-attribute
__init__(start, end)
__iter__()

Allow unpacking into start, end pairs.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 57-60
def __iter__(self):
    """Allow unpacking into start, end pairs."""
    yield self.start
    yield self.end
SegmentIterator

Iterator for generating line segments of specified size.

Produces segments of lines with start/end indices following 1-based indexing. The final segment may be smaller than the specified segment size.

Attributes:

    total_lines: Total number of lines in text
    segment_size: Number of lines per segment
    start_line: Starting line number (1-based)
    min_segment_size: Minimum size for the final segment

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 62-141
class SegmentIterator:
    """
    Iterator for generating line segments of specified size.

    Produces segments of lines with start/end indices following 1-based indexing.
    The final segment may be smaller than the specified segment size.

    Attributes:
        total_lines: Total number of lines in text
        segment_size: Number of lines per segment
        start_line: Starting line number (1-based)
        min_segment_size: Minimum size for the final segment
    """

    def __init__(
        self,
        total_lines: int,
        segment_size: int,
        start_line: int = 1,
        min_segment_size: Optional[int] = None,
    ):
        """
        Initialize the segment iterator.

        Args:
            total_lines: Total number of lines to iterate over
            segment_size: Desired size of each segment
            start_line: First line number (default: 1)
            min_segment_size: Minimum size for final segment (default: None)
                If specified, the last segment will be merged with the previous one
                if it would be smaller than this size.

        Raises:
            ValueError: If segment_size < 1 or total_lines < 1
            ValueError: If start_line < 1 (must use 1-based indexing)
            ValueError: If min_segment_size >= segment_size
        """
        if segment_size < 1:
            raise ValueError("Segment size must be at least 1")
        if total_lines < 1:
            raise ValueError("Total lines must be at least 1")
        if start_line < 1:
            raise ValueError("Start line must be at least 1 (1-based indexing)")
        if min_segment_size is not None and min_segment_size >= segment_size:
            raise ValueError("Minimum segment size must be less than segment size")

        self.total_lines = total_lines
        self.segment_size = segment_size
        self.start_line = start_line
        self.min_segment_size = min_segment_size

        # Calculate number of segments
        remaining_lines = total_lines - start_line + 1
        self.num_segments = (remaining_lines + segment_size - 1) // segment_size

    def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
        """
        Iterate over line segments.

        Yields:
            LineSegment containing start (inclusive) and end (exclusive) indices
        """
        current = self.start_line

        for i in range(self.num_segments):
            is_last_segment = i == self.num_segments - 1
            segment_end = min(current + self.segment_size, self.total_lines + 1)

            # Handle minimum segment size for last segment
            if (
                is_last_segment
                and self.min_segment_size is not None
                and segment_end - current < self.min_segment_size
                and i > 0
            ):
                # Merge with previous segment by not yielding
                break

            yield NumberedText.LineSegment(current, segment_end)
            current = segment_end
min_segment_size = min_segment_size instance-attribute
num_segments = (remaining_lines + segment_size - 1) // segment_size instance-attribute
segment_size = segment_size instance-attribute
start_line = start_line instance-attribute
total_lines = total_lines instance-attribute
__init__(total_lines, segment_size, start_line=1, min_segment_size=None)

Initialize the segment iterator.

Parameters:

    total_lines (int): Total number of lines to iterate over (required)
    segment_size (int): Desired size of each segment (required)
    start_line (int): First line number (default: 1)
    min_segment_size (Optional[int]): Minimum size for final segment (default: None).
        If specified, the last segment will be merged with the previous one
        if it would be smaller than this size.

Raises:

    ValueError: If segment_size < 1 or total_lines < 1
    ValueError: If start_line < 1 (must use 1-based indexing)
    ValueError: If min_segment_size >= segment_size

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 76-115
def __init__(
    self,
    total_lines: int,
    segment_size: int,
    start_line: int = 1,
    min_segment_size: Optional[int] = None,
):
    """
    Initialize the segment iterator.

    Args:
        total_lines: Total number of lines to iterate over
        segment_size: Desired size of each segment
        start_line: First line number (default: 1)
        min_segment_size: Minimum size for final segment (default: None)
            If specified, the last segment will be merged with the previous one
            if it would be smaller than this size.

    Raises:
        ValueError: If segment_size < 1 or total_lines < 1
        ValueError: If start_line < 1 (must use 1-based indexing)
        ValueError: If min_segment_size >= segment_size
    """
    if segment_size < 1:
        raise ValueError("Segment size must be at least 1")
    if total_lines < 1:
        raise ValueError("Total lines must be at least 1")
    if start_line < 1:
        raise ValueError("Start line must be at least 1 (1-based indexing)")
    if min_segment_size is not None and min_segment_size >= segment_size:
        raise ValueError("Minimum segment size must be less than segment size")

    self.total_lines = total_lines
    self.segment_size = segment_size
    self.start_line = start_line
    self.min_segment_size = min_segment_size

    # Calculate number of segments
    remaining_lines = total_lines - start_line + 1
    self.num_segments = (remaining_lines + segment_size - 1) // segment_size
__iter__()

Iterate over line segments.

Yields:

    LineSegment: containing start (inclusive) and end (exclusive) indices

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 117-141
def __iter__(self) -> Iterator["NumberedText.LineSegment"]:
    """
    Iterate over line segments.

    Yields:
        LineSegment containing start (inclusive) and end (exclusive) indices
    """
    current = self.start_line

    for i in range(self.num_segments):
        is_last_segment = i == self.num_segments - 1
        segment_end = min(current + self.segment_size, self.total_lines + 1)

        # Handle minimum segment size for last segment
        if (
            is_last_segment
            and self.min_segment_size is not None
            and segment_end - current < self.min_segment_size
            and i > 0
        ):
            # Merge with previous segment by not yielding
            break

        yield NumberedText.LineSegment(current, segment_end)
        current = segment_end
__getitem__(index)

Get line content by line number (1-based indexing).

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 247-249
def __getitem__(self, index: int) -> str:
    """Get line content by line number (1-based indexing)."""
    return self.lines[self._to_internal_index(index)]
__init__(content=None, start=1, separator=':')

Initialize a numbered text document, detecting and preserving existing numbering.

Valid numbered text must have:

  • Sequential line numbers
  • Consistent separator character(s)
  • Every non-empty line must follow the numbering pattern

Parameters:

    content (Optional[str]): Initial text content, if any (default: None)
    start (int): Starting line number (used only if content isn't already numbered). Defaults to 1.
    separator (str): Separator between line numbers and content (only if content isn't numbered). Defaults to ':'.

Examples:

>>> # Custom separators
>>> doc = NumberedText("1→First line\n2→Second line")
>>> doc.separator == "→"
True
>>> # Preserves starting number
>>> doc = NumberedText("5#First\n6#Second")
>>> doc.start == 5
True
>>> # Regular numbered list isn't treated as line numbers
>>> doc = NumberedText("1. First item\n2. Second item")
>>> doc.numbered_lines
['1: 1. First item', '2: 2. Second item']
Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 143-212
def __init__(
    self, content: Optional[str] = None, start: int = 1, separator: str = ":"
) -> None:
    """
    Initialize a numbered text document, detecting and preserving existing numbering.

    Valid numbered text must have:
    - Sequential line numbers
    - Consistent separator character(s)
    - Every non-empty line must follow the numbering pattern

    Args:
        content: Initial text content, if any
        start: Starting line number (used only if content isn't already numbered)
        separator: Separator between line numbers and content (only if content isn't numbered)

    Examples:
        >>> # Custom separators
        >>> doc = NumberedText("1→First line\\n2→Second line")
        >>> doc.separator == "→"
        True

        >>> # Preserves starting number
        >>> doc = NumberedText("5#First\\n6#Second")
        >>> doc.start == 5
        True

        >>> # Regular numbered list isn't treated as line numbers
        >>> doc = NumberedText("1. First item\\n2. Second item")
        >>> doc.numbered_lines
        ['1: 1. First item', '2: 2. Second item']
    """

    self.lines: List[str] = []  # Declare lines here
    self.start: int = start  # Declare start with its type
    self.separator: str = separator  # and separator

    if not isinstance(content, str):
        raise ValueError("NumberedText requires string input.")

    if start < 1:  # enforce 1 based indexing.
        raise IndexError(
            "NumberedText: Numbered lines must begin on an integer great or equal to 1."
        )

    if not content:
        return

    # Analyze the text format
    is_numbered, detected_sep, start_num = get_numbered_format(content)

    format_info = get_numbered_format(content)

    if format_info.is_numbered:
        self.start = format_info.start_num  # type: ignore
        self.separator = format_info.separator  # type: ignore

        # Extract content by removing number and separator
        pattern = re.compile(rf"^\d+{re.escape(detected_sep)}")
        self.lines = []

        for line in content.splitlines():
            if line.strip():
                self.lines.append(pattern.sub("", line))
            else:
                self.lines.append(line)
    else:
        self.lines = content.splitlines()
        self.start = start
        self.separator = separator
__iter__()

Iterate over (line_number, line_content) pairs.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 243-245
def __iter__(self) -> Iterator[tuple[int, str]]:
    """Iterate over (line_number, line_content) pairs."""
    return iter((i, line) for i, line in enumerate(self.lines, self.start))
__len__()

Return the number of lines.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 239-241
def __len__(self) -> int:
    """Return the number of lines."""
    return len(self.lines)
__str__()

Return the numbered text representation.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 233-237
def __str__(self) -> str:
    """Return the numbered text representation."""
    return "\n".join(
        self._format_line(i, line) for i, line in enumerate(self.lines, self.start)
    )
append(text)

Append text, splitting into lines if needed.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 324-326
def append(self, text: str) -> None:
    """Append text, splitting into lines if needed."""
    self.lines.extend(text.splitlines())
from_file(path, **kwargs) classmethod

Create a NumberedText instance from a file.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 214-217
@classmethod
def from_file(cls, path: Path, **kwargs) -> "NumberedText":
    """Create a NumberedText instance from a file."""
    return cls(Path(path).read_text(), **kwargs)
get_line(line_num)

Get content of specified line number.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 251-253
def get_line(self, line_num: int) -> str:
    """Get content of specified line number."""
    return self[line_num]
get_lines(start, end)

Get content of line range, not inclusive of end line.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 263-265
def get_lines(self, start: int, end: int) -> List[str]:
    """Get content of line range, not inclusive of end line."""
    return self.lines[self._to_internal_index(start) : self._to_internal_index(end)]
get_numbered_line(line_num)

Get specified line with line number.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 258-261
def get_numbered_line(self, line_num: int) -> str:
    """Get specified line with line number."""
    idx = self._to_line_index(line_num)
    return self._format_line(idx, self[idx])
get_numbered_lines(start, end)
Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 267-271
def get_numbered_lines(self, start: int, end: int) -> List[str]:
    return [
        self._format_line(i + self._to_internal_index(start) + 1, line)
        for i, line in enumerate(self.get_lines(start, end))
    ]
get_numbered_segment(start, end)
Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 310-311
def get_numbered_segment(self, start: int, end: int) -> str:
    return "\n".join(self.get_numbered_lines(start, end))
get_segment(start, end)
Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 273-280
def get_segment(self, start: int, end: int) -> str:
    if start < self.start:
        raise IndexError(f"Start index {start} is before first line {self.start}")
    if end > len(self) + 1:
        raise IndexError(f"End index {end} is past last line {len(self)}")
    if start >= end:
        raise IndexError(f"Start index {start} must be less than end index {end}")
    return "\n".join(self.get_lines(start, end))
insert(line_num, text)

Insert text at specified line number. Assumes text is not empty.

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 328-332
def insert(self, line_num: int, text: str) -> None:
    """Insert text at specified line number. Assumes text is not empty."""
    new_lines = text.splitlines()
    internal_idx = self._to_internal_index(line_num)
    self.lines[internal_idx:internal_idx] = new_lines
iter_segments(segment_size, min_segment_size=None)

Iterate over segments of the text with specified size.

Parameters:

    segment_size (int): Number of lines per segment (required)
    min_segment_size (Optional[int]): Optional minimum size for final segment (default: None).
        If specified, last segment will be merged with previous one
        if it would be smaller than this size.

Yields:

    LineSegment: objects containing start and end line numbers

Example

>>> text = NumberedText("line1\nline2\nline3\nline4\nline5")
>>> for segment in text.iter_segments(2):
...     print(f"Lines {segment.start}-{segment.end}")
Lines 1-3
Lines 3-5
Lines 5-6

Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 282-308
def iter_segments(
    self, segment_size: int, min_segment_size: Optional[int] = None
) -> Iterator[LineSegment]:
    """
    Iterate over segments of the text with specified size.

    Args:
        segment_size: Number of lines per segment
        min_segment_size: Optional minimum size for final segment.
            If specified, last segment will be merged with previous one
            if it would be smaller than this size.

    Yields:
        LineSegment objects containing start and end line numbers

    Example:
        >>> text = NumberedText("line1\\nline2\\nline3\\nline4\\nline5")
        >>> for segment in text.iter_segments(2):
        ...     print(f"Lines {segment.start}-{segment.end}")
        Lines 1-3
        Lines 3-5
        Lines 5-6
    """
    iterator = self.SegmentIterator(
        len(self), segment_size, self.start, min_segment_size
    )
    return iter(iterator)
save(path, numbered=True)

Save document to file.

Parameters:

    path (Path): Output file path (required)
    numbered (bool): Whether to save with line numbers (default: True)
Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 313-322
def save(self, path: Path, numbered: bool = True) -> None:
    """
    Save document to file.

    Args:
        path: Output file path
        numbered: Whether to save with line numbers (default: True)
    """
    content = str(self) if numbered else "\n".join(self.lines)
    Path(path).write_text(content)
get_numbered_format(text)

Analyze text to determine if it follows a consistent line numbering format.

Valid formats have:

  • Sequential numbers starting from some value
  • Consistent separator character(s)
  • Every line must follow the format

Parameters:

    text (str): Text to analyze (required)

Returns:

    NumberedFormat: Tuple of (is_numbered, separator, start_number)

Examples:

>>> get_numbered_format("1→First\n2→Second")
(True, "→", 1)
>>> get_numbered_format("1. First")  # Numbered list format
(False, None, None)
>>> get_numbered_format("5#Line\n6#Other")
(True, "#", 5)
Source code in src/tnh_scholar/text_processing/numbered_text.py, lines 372-409
def get_numbered_format(text: str) -> NumberedFormat:
    """
    Analyze text to determine if it follows a consistent line numbering format.

    Valid formats have:
    - Sequential numbers starting from some value
    - Consistent separator character(s)
    - Every line must follow the format

    Args:
        text: Text to analyze

    Returns:
        Tuple of (is_numbered, separator, start_number)

    Examples:
        >>> get_numbered_format("1→First\\n2→Second")
        (True, "→", 1)
        >>> get_numbered_format("1. First")  # Numbered list format
        (False, None, None)
        >>> get_numbered_format("5#Line\\n6#Other")
        (True, "#", 5)
    """
    if not text.strip():
        return NumberedFormat(False)

    lines = [line for line in text.splitlines() if line.strip()]
    if not lines:
        return NumberedFormat(False)

    # Try to detect pattern from first line
    SEPARATOR_PATTERN = r"[^\w\s.]"  # not (word char or whitespace or period)
    first_match = re.match(rf"^(\d+)({SEPARATOR_PATTERN})(.*?)$", lines[0])

    try:
        return _check_line_structure(first_match, lines)
    except (ValueError, AttributeError):
        return NumberedFormat(False)
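
For illustration, a minimal sketch mirroring the docstring examples above (hypothetical input). The first result assumes the helper _check_line_structure, which is not shown here, validates sequential numbering as the docstring describes; the second follows directly from the separator pattern, which excludes '.' so ordinary numbered lists are not treated as line numbers.

>>> fmt = get_numbered_format("5#Line\n6#Other")
>>> fmt.is_numbered, fmt.separator, fmt.start_num
(True, '#', 5)
>>> get_numbered_format("1. First item").is_numbered
False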

simple_section

MatchObject

Bases: BaseModel

Basic Match Object definition.

Source code in src/tnh_scholar/text_processing/simple_section.py, lines 10-17
class MatchObject(BaseModel):
    """Basic Match Object definition."""
    type: str
    level: Optional[int] = None
    words: Optional[List[str]] = None
    case_sensitive: Optional[bool] = False
    decorator: Optional[str] = None
    pattern: Optional[str] = None
case_sensitive = False class-attribute instance-attribute
decorator = None class-attribute instance-attribute
level = None class-attribute instance-attribute
pattern = None class-attribute instance-attribute
type instance-attribute
words = None class-attribute instance-attribute
SectionConfig

Bases: BaseModel

Configuration for section detection.

Source code in src/tnh_scholar/text_processing/simple_section.py, lines 19-23
class SectionConfig(BaseModel):
    """Configuration for section detection."""
    name: str
    description: Optional[str] = None
    patterns: List[MatchObject]
description = None class-attribute instance-attribute
name instance-attribute
patterns instance-attribute
create_text_object(text, boundaries)

Create TextObject from text and section boundaries.

Source code in src/tnh_scholar/text_processing/simple_section.py, lines 75-102
def create_text_object(text: str, boundaries: List[int]) -> TextObject:
    """Create TextObject from text and section boundaries."""
    lines = text.splitlines()
    sections = []

    # Handle first section starting after line 1
    if not boundaries or boundaries[0] != 1:
        boundaries.insert(0, 1)

    # Create sections from boundaries
    for i in range(len(boundaries)):
        start = boundaries[i]
        end = boundaries[i + 1] - 1 if i + 1 < len(boundaries) else len(lines)

        # Get section title from first line
        title = lines[start - 1].strip()

        section = LogicalSection(
            title=title,
            start_line=start,
            end_line=end
        )
        sections.append(section)

    return TextObject(
        language="en",  # Default to English for PoC
        sections=sections
    )
find_keyword(line, words, case_sensitive, decorator)

Check if line matches keyword pattern.

Source code in src/tnh_scholar/text_processing/simple_section.py
def find_keyword(line: str, words: List[str], case_sensitive: bool, decorator: Optional[str]) -> bool:
    """Check if line matches keyword pattern."""
    if not case_sensitive:
        line = line.lower()
        words = [w.lower() for w in words]

    # Check if line starts with any keyword
    if not any(line.lstrip().startswith(word) for word in words):
        return False

    # If decorator specified, check if it appears in line
    return not decorator or decorator in line
find_markdown_header(line, level)

Check if line matches markdown header pattern.

Source code in src/tnh_scholar/text_processing/simple_section.py
def find_markdown_header(line: str, level: int) -> bool:
    """Check if line matches markdown header pattern."""
    stripped = line.lstrip()
    return stripped.startswith('#' * level + ' ')
find_regex(line, pattern)

Check if line matches regex pattern.

Source code in src/tnh_scholar/text_processing/simple_section.py
def find_regex(line: str, pattern: str) -> bool:
    """Check if line matches regex pattern."""
    return bool(re.match(pattern, line))
find_section_boundaries(text, config)

Find all section boundary line numbers.

Source code in src/tnh_scholar/text_processing/simple_section.py
def find_section_boundaries(text: str, config: SectionConfig) -> List[int]:
    """Find all section boundary line numbers."""
    boundaries = []

    for i, line in enumerate(text.splitlines(), 1):
        for pattern in config.patterns:
            matched = False

            if pattern.type == "markdown_header" and pattern.level:
                matched = find_markdown_header(line, pattern.level)

            elif pattern.type == "keyword" and pattern.words:
                matched = find_keyword(
                    line, 
                    pattern.words,
                    pattern.case_sensitive or False,
                    pattern.decorator
                )

            elif pattern.type == "regex" and pattern.pattern:
                matched = find_regex(line, pattern.pattern)

            if matched:
                boundaries.append(i)
                break  # Stop checking patterns if we found a match

    return boundaries
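
A short usage sketch tying these helpers together (assuming MatchObject, SectionConfig, and find_section_boundaries are importable from tnh_scholar.text_processing.simple_section, per the source path above):

from tnh_scholar.text_processing.simple_section import (
    MatchObject,
    SectionConfig,
    find_section_boundaries,
)

text = "# Intro\nSome text.\n## Practice\nMore text.\nNOTE: closing remarks"

config = SectionConfig(
    name="demo",
    description="Split on level-1 headers and NOTE keywords",
    patterns=[
        MatchObject(type="markdown_header", level=1),
        MatchObject(type="keyword", words=["NOTE"], case_sensitive=True),
    ],
)

# 1-based line numbers where a new section starts.
print(find_section_boundaries(text, config))  # -> [1, 5]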

text_object

Text object system for managing sectioned content with metadata.

This module provides the core TextObject implementation, handling both internal representation and API interactions. It uses Dublin Core for metadata standards and provides a simplified format for AI service integration.

LogicalSection

Bases: BaseModel

Represents a logical division of text content.

Source code in src/tnh_scholar/text_processing/text_object.py
class LogicalSection(BaseModel):
    """Represents a logical division of text content."""

    start_line: int = Field(..., description="Starting line number of section (inclusive)")
    title: str = Field(..., description="Title describing section content")

    def __lt__(self, other: "LogicalSection") -> bool:
        """Enable sorting by start line."""
        return self.start_line < other.start_line
start_line = Field(..., description='Starting line number of section (inclusive)') class-attribute instance-attribute
title = Field(..., description='Title describing section content') class-attribute instance-attribute
__lt__(other)

Enable sorting by start line.

Source code in src/tnh_scholar/text_processing/text_object.py
def __lt__(self, other: "LogicalSection") -> bool:
    """Enable sorting by start line."""
    return self.start_line < other.start_line
TextMetadata dataclass

Rich metadata container following Dublin Core standards.

Source code in src/tnh_scholar/text_processing/text_object.py
@dataclass
class TextMetadata:
    """Rich metadata container following Dublin Core standards."""

    # Core Dublin Core elements
    title: str
    creator: List[str]
    subject: List[str]
    description: str
    publisher: Optional[str] = None
    contributor: List[str] = field(default_factory=list)
    date: Optional[str] = None
    type: str = "Text"
    format: str = "text/plain"
    identifier: Optional[str] = None
    source: Optional[str] = None
    language: str = "en"

    # Additional contextual information
    context: str = field(default="")
    additional_info: Dict[str, Any] = field(default_factory=dict)

    def to_string(self) -> str:
        """Convert metadata to human-readable string format."""
        parts = []

        # Add core elements if present
        if self.title:
            parts.append(f"Title: {self.title}")
        if self.creator:
            parts.append(f"Creator(s): {', '.join(self.creator)}")
        if self.subject:
            parts.append(f"Subject(s): {', '.join(self.subject)}")
        if self.description:
            parts.append(f"Description: {self.description}")
        if self.publisher:
            parts.append(f"Publisher: {self.publisher}")
        if self.contributor:
            parts.append(f"Contributor(s): {', '.join(self.contributor)}")
        if self.date:
            parts.append(f"Date: {self.date}")

        parts.extend((f"Type: {self.type}", f"Format: {self.format}"))

        if self.identifier:
            parts.append(f"Identifier: {self.identifier}")
        if self.source:
            parts.append(f"Source: {self.source}")

        parts.append(f"Language: {self.language}")

        return "\n".join(parts)

    @classmethod
    def from_string(cls, metadata_str: str, context: str = "") -> "TextMetadata":
        """Parse metadata from string representation."""
        fields: Dict[str, Any] = {
            "creator": [],
            "subject": [],
            "contributor": [],
            "context": context
        }

        for line in metadata_str.splitlines():
            if ":" not in line:
                continue

            key, value = line.split(":", 1)
            key = key.strip().lower()
            value = value.strip()

            if key in ["creator(s)", "creator"]:
                fields["creator"] = [c.strip() for c in value.split(",")]
            elif key == "subject(s)" or key == "subject":
                fields["subject"] = [s.strip() for s in value.split(",")]
            elif key == "contributor(s)" or key == "contributor":
                fields["contributor"] = [c.strip() for c in value.split(",")]
            else:
                # Convert key to match dataclass field names
                key = key.replace("(s)", "")
                if key in cls.__dataclass_fields__:
                    fields[key] = value

        return cls(**fields)
additional_info = field(default_factory=dict) class-attribute instance-attribute
context = field(default='') class-attribute instance-attribute
contributor = field(default_factory=list) class-attribute instance-attribute
creator instance-attribute
date = None class-attribute instance-attribute
description instance-attribute
format = 'text/plain' class-attribute instance-attribute
identifier = None class-attribute instance-attribute
language = 'en' class-attribute instance-attribute
publisher = None class-attribute instance-attribute
source = None class-attribute instance-attribute
subject instance-attribute
title instance-attribute
type = 'Text' class-attribute instance-attribute
__init__(title, creator, subject, description, publisher=None, contributor=list(), date=None, type='Text', format='text/plain', identifier=None, source=None, language='en', context='', additional_info=dict())
from_string(metadata_str, context='') classmethod

Parse metadata from string representation.

Source code in src/tnh_scholar/text_processing/text_object.py
@classmethod
def from_string(cls, metadata_str: str, context: str = "") -> "TextMetadata":
    """Parse metadata from string representation."""
    fields: Dict[str, Any] = {
        "creator": [],
        "subject": [],
        "contributor": [],
        "context": context
    }

    for line in metadata_str.splitlines():
        if ":" not in line:
            continue

        key, value = line.split(":", 1)
        key = key.strip().lower()
        value = value.strip()

        if key in ["creator(s)", "creator"]:
            fields["creator"] = [c.strip() for c in value.split(",")]
        elif key == "subject(s)" or key == "subject":
            fields["subject"] = [s.strip() for s in value.split(",")]
        elif key == "contributor(s)" or key == "contributor":
            fields["contributor"] = [c.strip() for c in value.split(",")]
        else:
            # Convert key to match dataclass field names
            key = key.replace("(s)", "")
            if key in cls.__dataclass_fields__:
                fields[key] = value

    return cls(**fields)
to_string()

Convert metadata to human-readable string format.

Source code in src/tnh_scholar/text_processing/text_object.py
def to_string(self) -> str:
    """Convert metadata to human-readable string format."""
    parts = []

    # Add core elements if present
    if self.title:
        parts.append(f"Title: {self.title}")
    if self.creator:
        parts.append(f"Creator(s): {', '.join(self.creator)}")
    if self.subject:
        parts.append(f"Subject(s): {', '.join(self.subject)}")
    if self.description:
        parts.append(f"Description: {self.description}")
    if self.publisher:
        parts.append(f"Publisher: {self.publisher}")
    if self.contributor:
        parts.append(f"Contributor(s): {', '.join(self.contributor)}")
    if self.date:
        parts.append(f"Date: {self.date}")

    parts.extend((f"Type: {self.type}", f"Format: {self.format}"))

    if self.identifier:
        parts.append(f"Identifier: {self.identifier}")
    if self.source:
        parts.append(f"Source: {self.source}")

    parts.append(f"Language: {self.language}")

    return "\n".join(parts)
TextMetadataFormat

Bases: BaseModel

Simplified metadata format optimized for AI processing.

Source code in src/tnh_scholar/text_processing/text_object.py
class TextMetadataFormat(BaseModel):
    """Simplified metadata format optimized for AI processing."""

    metadata_summary: str = Field(
        ..., 
        description="Available metadata in human-readable format"
    )
    context: str = Field(
        ...,
        description="Rich contextual information for AI understanding"
    )
context = Field(..., description='Rich contextual information for AI understanding') class-attribute instance-attribute
metadata_summary = Field(..., description='Available metadata in human-readable format') class-attribute instance-attribute
TextObject

Main class for managing sectioned text content with metadata.

Source code in src/tnh_scholar/text_processing/text_object.py
class TextObject:
    """Main class for managing sectioned text content with metadata."""

    def __init__(
        self,
        numbered_text: NumberedText,
        language: str,
        sections: List[LogicalSection],
        metadata: TextMetadata
    ) -> None:
        """
        Initialize TextObject with content and metadata.

        Args:
            numbered_text: Text content with line numbering
            language: ISO 639-1 language code
            sections: List of logical sections
            metadata: Dublin Core metadata

        Raises:
            ValueError: If sections are invalid or text is empty
        """
        self.content = numbered_text
        self.language = get_language_code(language)
        self.sections = sorted(sections)
        self.metadata = metadata
        self.total_lines = numbered_text.size
        self._validate()

    def _validate(self) -> None:
        """
        Validate section integrity.

        Raises:
            ValueError: If sections are invalid
        """
        if not self.sections:
            raise ValueError("TextObject must have at least one section")

        # Validate section ordering
        for i, section in enumerate(self.sections):
            if section.start_line < 1:
                raise ValueError(f"Section {i}: start line must be >= 1")
            if section.start_line > self.total_lines:
                raise ValueError(f"Section {i}: start line exceeds text length")
            if i > 0 and section.start_line <= self.sections[i-1].start_line:
                raise ValueError(f"Section {i}: non-sequential start line")

    def get_section_content(self, index: int) -> str:
        """
        Retrieve content for specific section.

        Args:
            index: Section index

        Returns:
            Text content for the section

        Raises:
            IndexError: If index is out of range
        """
        if not 0 <= index < len(self.sections):
            raise IndexError(f"Section index {index} out of range")

        start = self.sections[index].start_line
        end = (self.sections[index + 1].start_line 
               if index < len(self.sections) - 1 
               else self.total_lines + 1)

        return self.content.get_segment(start, end)

    def to_response_format(self) -> TextObjectFormat:
        """
        Convert to API format.

        Returns:
            TextObjectFormat for API interaction
        """
        metadata_summary = self.metadata.to_string()

        return TextObjectFormat(
            metadata=TextMetadataFormat(
                metadata_summary=metadata_summary,
                context=self.metadata.context
            ),
            language=self.language,
            sections=self.sections
        )

    @classmethod
    def from_response_format(
        cls, 
        text: NumberedText,
        response: TextObjectFormat
    ) -> "TextObject":
        """
        Create from API response.

        Args:
            text: Text content
            response: API response format

        Returns:
            New TextObject instance
        """
        metadata = TextMetadata.from_string(
            response.metadata.metadata_summary,
            response.metadata.context
        )

        return cls(
            numbered_text=text,
            language=response.language,
            sections=response.sections,
            metadata=metadata
        )

    def __len__(self) -> int:
        """Return number of sections."""
        return len(self.sections)

    def __iter__(self):
        """Iterate over sections."""
        return iter(self.sections)
content = numbered_text instance-attribute
language = get_language_code(language) instance-attribute
metadata = metadata instance-attribute
sections = sorted(sections) instance-attribute
total_lines = numbered_text.size instance-attribute
__init__(numbered_text, language, sections, metadata)

Initialize TextObject with content and metadata.

Parameters:

Name Type Description Default
numbered_text NumberedText

Text content with line numbering

required
language str

ISO 639-1 language code

required
sections List[LogicalSection]

List of logical sections

required
metadata TextMetadata

Dublin Core metadata

required

Raises:

Type Description
ValueError

If sections are invalid or text is empty

Source code in src/tnh_scholar/text_processing/text_object.py
def __init__(
    self,
    numbered_text: NumberedText,
    language: str,
    sections: List[LogicalSection],
    metadata: TextMetadata
) -> None:
    """
    Initialize TextObject with content and metadata.

    Args:
        numbered_text: Text content with line numbering
        language: ISO 639-1 language code
        sections: List of logical sections
        metadata: Dublin Core metadata

    Raises:
        ValueError: If sections are invalid or text is empty
    """
    self.content = numbered_text
    self.language = get_language_code(language)
    self.sections = sorted(sections)
    self.metadata = metadata
    self.total_lines = numbered_text.size
    self._validate()
__iter__()

Iterate over sections.

Source code in src/tnh_scholar/text_processing/text_object.py
def __iter__(self):
    """Iterate over sections."""
    return iter(self.sections)
__len__()

Return number of sections.

Source code in src/tnh_scholar/text_processing/text_object.py
def __len__(self) -> int:
    """Return number of sections."""
    return len(self.sections)
from_response_format(text, response) classmethod

Create from API response.

Parameters:

Name Type Description Default
text NumberedText

Text content

required
response TextObjectFormat

API response format

required

Returns:

Type Description
TextObject

New TextObject instance

Source code in src/tnh_scholar/text_processing/text_object.py
@classmethod
def from_response_format(
    cls, 
    text: NumberedText,
    response: TextObjectFormat
) -> "TextObject":
    """
    Create from API response.

    Args:
        text: Text content
        response: API response format

    Returns:
        New TextObject instance
    """
    metadata = TextMetadata.from_string(
        response.metadata.metadata_summary,
        response.metadata.context
    )

    return cls(
        numbered_text=text,
        language=response.language,
        sections=response.sections,
        metadata=metadata
    )
get_section_content(index)

Retrieve content for specific section.

Parameters:

Name Type Description Default
index int

Section index

required

Returns:

Type Description
str

Text content for the section

Raises:

Type Description
IndexError

If index is out of range

Source code in src/tnh_scholar/text_processing/text_object.py
def get_section_content(self, index: int) -> str:
    """
    Retrieve content for specific section.

    Args:
        index: Section index

    Returns:
        Text content for the section

    Raises:
        IndexError: If index is out of range
    """
    if not 0 <= index < len(self.sections):
        raise IndexError(f"Section index {index} out of range")

    start = self.sections[index].start_line
    end = (self.sections[index + 1].start_line 
           if index < len(self.sections) - 1 
           else self.total_lines + 1)

    return self.content.get_segment(start, end)
to_response_format()

Convert to API format.

Returns:

Type Description
TextObjectFormat

TextObjectFormat for API interaction

Source code in src/tnh_scholar/text_processing/text_object.py
def to_response_format(self) -> TextObjectFormat:
    """
    Convert to API format.

    Returns:
        TextObjectFormat for API interaction
    """
    metadata_summary = self.metadata.to_string()

    return TextObjectFormat(
        metadata=TextMetadataFormat(
            metadata_summary=metadata_summary,
            context=self.metadata.context
        ),
        language=self.language,
        sections=self.sections
    )
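
A hedged construction sketch for TextObject. It assumes NumberedText can be built directly from a raw string (NumberedText(raw_text)); that constructor signature is not shown in this section and is an assumption:

from tnh_scholar.text_processing.numbered_text import NumberedText
from tnh_scholar.text_processing.text_object import (
    LogicalSection,
    TextMetadata,
    TextObject,
)

raw_text = "Welcome\nFirst teaching...\nSecond teaching..."

text_obj = TextObject(
    numbered_text=NumberedText(raw_text),  # assumed constructor
    language="English",                    # passed through get_language_code in __init__
    sections=[
        LogicalSection(start_line=1, title="Welcome"),
        LogicalSection(start_line=2, title="Teachings"),
    ],
    metadata=TextMetadata(
        title="Demo",
        creator=["Unknown"],
        subject=["example"],
        description="Minimal demo object.",
    ),
)

print(len(text_obj))                    # number of sections
print(text_obj.get_section_content(0))  # content of the first section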
TextObjectFormat

Bases: BaseModel

Complete format for API interactions.

Source code in src/tnh_scholar/text_processing/text_object.py
class TextObjectFormat(BaseModel):
    """Complete format for API interactions."""

    metadata: TextMetadataFormat
    language: str = Field(..., description="ISO 639-1 language code")
    sections: List[LogicalSection]
language = Field(..., description='ISO 639-1 language code') class-attribute instance-attribute
metadata instance-attribute
sections instance-attribute

text_processing

clean_text(text, newline=False)

Cleans a given text by replacing specific unwanted characters, such as tabs and non-breaking spaces, with regular spaces.

This function takes a string as input and applies replacements based on a predefined mapping of characters to replace.

Parameters:

Name Type Description Default
text str

The text to be cleaned.

required

Returns:

Name Type Description
str

The cleaned text with unwanted characters replaced by spaces.

Example

text = "This is\n an example\ttext with\xa0extra spaces." clean_text(text) 'This is an example text with extra spaces.'

Source code in src/tnh_scholar/text_processing/text_processing.py
def clean_text(text, newline=False):
    """
    Cleans a given text by replacing specific unwanted characters, such as
    tabs and non-breaking spaces, with regular spaces.

    This function takes a string as input and applies replacements
    based on a predefined mapping of characters to replace.

    Args:
        text (str): The text to be cleaned.

    Returns:
        str: The cleaned text with unwanted characters replaced by spaces.

    Example:
        >>> text = "This is\\n an example\\ttext with\\xa0extra spaces."
        >>> clean_text(text)
        'This is an example text with extra spaces.'

    """
    # Define a mapping of characters to replace
    replace_map = {
        "\t": " ",  # Replace tabs with space
        "\xa0": " ",  # Replace non-breaking space with regular space
        # Add more replacements as needed
    }

    if newline:
        replace_map["\n"] = ""  # remove newlines

    # Loop through the replace map and replace each character
    for old_char, new_char in replace_map.items():
        text = text.replace(old_char, new_char)

    return text.strip()  # Ensure any leading/trailing spaces are removed
normalize_newlines(text, spacing=2)
Normalize newline blocks in the input text by reducing consecutive newlines
to the specified number of newlines for consistent readability and formatting.

Parameters:
----------
text : str
    The input text containing inconsistent newline spacing.
spacing : int, optional
    The number of newlines to insert between lines. Defaults to 2.

Returns:
-------
str
    The text with consecutive newlines reduced to the specified number of newlines.

Example:
--------
>>> raw_text = "Heading\n\n\nParagraph text 1\nParagraph text 2\n\n\n"
>>> normalize_newlines(raw_text, spacing=2)
'Heading\n\nParagraph text 1\n\nParagraph text 2\n\n'

Source code in src/tnh_scholar/text_processing/text_processing.py
def normalize_newlines(text: str, spacing: int = 2) -> str:
    """
    Normalize newline blocks in the input text by reducing consecutive newlines
    to the specified number of newlines for consistent readability and formatting.

    Parameters:
    ----------
    text : str
        The input text containing inconsistent newline spacing.
    spacing : int, optional
        The number of newlines to insert between lines. Defaults to 2.

    Returns:
    -------
    str
        The text with consecutive newlines reduced to the specified number of newlines.

    Example:
    --------
    >>> raw_text = "Heading\n\n\nParagraph text 1\nParagraph text 2\n\n\n"
    >>> normalize_newlines(raw_text, spacing=2)
    'Heading\n\nParagraph text 1\n\nParagraph text 2\n\n'
    """
    # Replace one or more newlines with the desired number of newlines
    newlines = "\n" * spacing
    return re.sub(r"\n{1,}", newlines, text)
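
A small combined sketch (assuming both helpers are importable from tnh_scholar.text_processing.text_processing, per the source path above):

from tnh_scholar.text_processing.text_processing import clean_text, normalize_newlines

raw = "Heading\n\n\n\tBody line one\xa0with odd spacing\nBody line two\n\n\n"

# Collapse newline runs to double-spaced blocks, then strip tabs and
# non-breaking spaces from the result.
normalized = normalize_newlines(raw, spacing=2)
cleaned = clean_text(normalized)
print(cleaned)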

utils

file_utils

FileExistsWarning

Bases: UserWarning

Source code in src/tnh_scholar/utils/file_utils.py
class FileExistsWarning(UserWarning):
    pass
copy_files_with_regex(source_dir, destination_dir, regex_patterns, preserve_structure=True)

Copies files from subdirectories one level down in the source directory to the destination directory if they match any regex pattern. Optionally preserves the directory structure.

Parameters:

Name Type Description Default
source_dir Path

Path to the source directory to search files in.

required
destination_dir Path

Path to the destination directory where files will be copied.

required
regex_patterns list[str]

List of regex patterns to match file names.

required
preserve_structure bool

Whether to preserve the directory structure. Defaults to True.

True

Raises:

Type Description
ValueError

If the source directory does not exist or is not a directory.

Example

>>> copy_files_with_regex(
...     source_dir=Path("/path/to/source"),
...     destination_dir=Path("/path/to/destination"),
...     regex_patterns=[r'.*\.txt$', r'.*\.log$'],
...     preserve_structure=True
... )

Source code in src/tnh_scholar/utils/file_utils.py
def copy_files_with_regex(
    source_dir: Path,
    destination_dir: Path,
    regex_patterns: list[str],
    preserve_structure: bool = True,
) -> None:
    """
    Copies files from subdirectories one level down in the source directory to the destination directory
    if they match any regex pattern. Optionally preserves the directory structure.

    Args:
        source_dir (Path): Path to the source directory to search files in.
        destination_dir (Path): Path to the destination directory where files will be copied.
        regex_patterns (list[str]): List of regex patterns to match file names.
        preserve_structure (bool): Whether to preserve the directory structure. Defaults to True.

    Raises:
        ValueError: If the source directory does not exist or is not a directory.

    Example:
        >>> copy_files_with_regex(
        ...     source_dir=Path("/path/to/source"),
        ...     destination_dir=Path("/path/to/destination"),
        ...     regex_patterns=[r'.*\\.txt$', r'.*\\.log$'],
        ...     preserve_structure=True
        ... )
    """
    if not source_dir.is_dir():
        raise ValueError(
            f"The source directory {source_dir} does not exist or is not a directory."
        )

    if not destination_dir.exists():
        destination_dir.mkdir(parents=True, exist_ok=True)

    # Compile regex patterns for efficiency
    compiled_patterns = [re.compile(pattern) for pattern in regex_patterns]

    # Process only one level down
    for subdir in source_dir.iterdir():
        if subdir.is_dir():  # Only process subdirectories
            print(f"processing {subdir}:")
            for file_path in subdir.iterdir():  # Only files in this subdirectory
                if file_path.is_file():
                    print(f"checking file: {file_path.name}")
                    # Check if the file matches any of the regex patterns
                    if any(
                        pattern.match(file_path.name) for pattern in compiled_patterns
                    ):
                        if preserve_structure:
                            # Construct the target path, preserving relative structure
                            relative_path = (
                                subdir.relative_to(source_dir) / file_path.name
                            )
                            target_path = destination_dir / relative_path
                            target_path.parent.mkdir(parents=True, exist_ok=True)
                        else:
                            # Place directly in destination without subdirectory structure
                            target_path = destination_dir / file_path.name

                        shutil.copy2(file_path, target_path)
                        print(f"Copied: {file_path} -> {target_path}")
ensure_directory_exists(dir_path)

Create directory if it doesn't exist.

Parameters:

Name Type Description Default
dir_path Path

Directory path to ensure exists.

required
Source code in src/tnh_scholar/utils/file_utils.py
def ensure_directory_exists(dir_path: Path) -> bool:
    """
    Create directory if it doesn't exist.

    Args:
        dir_path (Path): Directory path to ensure exists.

    Returns:
        bool: True if the directory exists after the call.
    """
    dir_path.mkdir(parents=True, exist_ok=True)
    return dir_path.exists()
get_text_from_file(file_path)

Reads the entire content of a text file.

Parameters:

Name Type Description Default
file_path Path

The path to the text file.

required

Returns:

Type Description
str

The content of the text file as a single string.

Source code in src/tnh_scholar/utils/file_utils.py
def get_text_from_file(file_path: Path) -> str:
    """Reads the entire content of a text file.

    Args:
        file_path: The path to the text file.

    Returns:
        The content of the text file as a single string.
    """

    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
iterate_subdir(directory, recursive=False)

Iterates through subdirectories in the given directory.

Parameters:

Name Type Description Default
directory Path

The root directory to start the iteration.

required
recursive bool

If True, iterates recursively through all subdirectories. If False, iterates only over the immediate subdirectories.

False

Yields:

Name Type Description
Path Path

Paths to each subdirectory.

Example

>>> for subdir in iterate_subdir(Path('/root'), recursive=False):
...     print(subdir)

Source code in src/tnh_scholar/utils/file_utils.py
def iterate_subdir(
    directory: Path, recursive: bool = False
) -> Generator[Path, None, None]:
    """
    Iterates through subdirectories in the given directory.

    Args:
        directory (Path): The root directory to start the iteration.
        recursive (bool): If True, iterates recursively through all subdirectories.
                          If False, iterates only over the immediate subdirectories.

    Yields:
        Path: Paths to each subdirectory.

    Example:
        >>> for subdir in iterate_subdir(Path('/root'), recursive=False):
        ...     print(subdir)
    """
    if recursive:
        for subdirectory in directory.rglob("*"):
            if subdirectory.is_dir():
                yield subdirectory
    else:
        for subdirectory in directory.iterdir():
            if subdirectory.is_dir():
                yield subdirectory
write_text_to_file(file_path, content, overwrite=False, append=False)

Writes text content to a file, handling overwriting and appending.

Parameters:

Name Type Description Default
file_path Path

The path to the file.

required
content str

The text content to write.

required
overwrite bool

If True, overwrites the file if it exists.

False
append bool

If True, appends the content to the file if it exists.

False

Raises:

Type Description
FileExistsWarning

If the file exists and neither overwrite nor append are True.

Source code in src/tnh_scholar/utils/file_utils.py
def write_text_to_file(
    file_path: Path, content: str, overwrite: bool = False, append: bool = False
) -> None:
    """Writes text content to a file, handling overwriting and appending.

    Args:
        file_path: The path to the file.
        content: The text content to write.
        overwrite: If True, overwrites the file if it exists.
        append: If True, appends the content to the file if it exists.

    Raises:
        FileExistsWarning: If the file exists and neither overwrite nor append are True.
    """

    if file_path.exists():
        if not overwrite and not append:
            warnings.warn(
                f"File '{file_path}' already exists. Use overwrite or append.",
                FileExistsWarning,
            )
            return  # Do not write if neither flag is set
        mode = "a" if append else "w"  # Choose mode based on flags
    else:
        mode = "w"  # Default to write mode if file doesn't exist

    with open(file_path, mode, encoding="utf-8") as file:
        file.write(content)
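
A read/write round-trip sketch (assuming the helpers are importable from tnh_scholar.utils.file_utils; the path below is illustrative only):

from pathlib import Path

from tnh_scholar.utils.file_utils import get_text_from_file, write_text_to_file

notes = Path("/tmp/tnh_notes.txt")  # hypothetical location

write_text_to_file(notes, "First line\n")                # creates the file
write_text_to_file(notes, "Second line\n", append=True)  # appends to it

print(get_text_from_file(notes))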

json_utils

format_json(file)

Formats a JSON file with line breaks and indentation for readability.

Parameters:

Name Type Description Default
file Path

Path to the JSON file to be formatted.

required
Example

format_json(Path("data.json"))

Source code in src/tnh_scholar/utils/json_utils.py
def format_json(file: Path) -> None:
    """
    Formats a JSON file with line breaks and indentation for readability.

    Args:
        file (Path): Path to the JSON file to be formatted.

    Example:
        format_json(Path("data.json"))
    """
    with file.open("r", encoding="utf-8") as f:
        data = json.load(f)

    with file.open("w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
load_json_into_model(file, model)

Loads a JSON file and validates it against a Pydantic model.

Parameters:

Name Type Description Default
file Path

Path to the JSON file.

required
model type[BaseModel]

The Pydantic model to validate against.

required

Returns:

Name Type Description
BaseModel BaseModel

An instance of the validated Pydantic model.

Raises:

Type Description
ValueError

If the file content is invalid JSON or does not match the model.

Example:

class ExampleModel(BaseModel):
    name: str
    age: int
    city: str

if __name__ == "__main__":
    json_file = Path("example.json")
    try:
        data = load_json_into_model(json_file, ExampleModel)
        print(data)
    except ValueError as e:
        print(e)
Source code in src/tnh_scholar/utils/json_utils.py
def load_json_into_model(file: Path, model: type[BaseModel]) -> BaseModel:
    """
    Loads a JSON file and validates it against a Pydantic model.

    Args:
        file (Path): Path to the JSON file.
        model (type[BaseModel]): The Pydantic model to validate against.

    Returns:
        BaseModel: An instance of the validated Pydantic model.

    Raises:
        ValueError: If the file content is invalid JSON or does not match the model.
    Example:
        class ExampleModel(BaseModel):
            name: str
            age: int
            city: str

        if __name__ == "__main__":
            json_file = Path("example.json")
            try:
                data = load_json_into_model(json_file, ExampleModel)
                print(data)
            except ValueError as e:
                print(e)
    """
    try:
        with file.open("r", encoding="utf-8") as f:
            data = json.load(f)
        return model(**data)
    except (json.JSONDecodeError, ValidationError) as e:
        raise ValueError(f"Error loading or validating JSON file '{file}': {e}")
load_jsonl_to_dict(file_path)

Load a JSONL file into a list of dictionaries.

Parameters:

Name Type Description Default
file_path Path

Path to the JSONL file.

required

Returns:

Type Description
List[Dict]

List[Dict]: A list of dictionaries, each representing a line in the JSONL file.

Example

>>> from pathlib import Path
>>> file_path = Path("data.jsonl")
>>> data = load_jsonl_to_dict(file_path)
>>> print(data)
[{'key1': 'value1'}, {'key2': 'value2'}]

Source code in src/tnh_scholar/utils/json_utils.py
def load_jsonl_to_dict(file_path: Path) -> List[Dict]:
    """
    Load a JSONL file into a list of dictionaries.

    Args:
        file_path (Path): Path to the JSONL file.

    Returns:
        List[Dict]: A list of dictionaries, each representing a line in the JSONL file.

    Example:
        >>> from pathlib import Path
        >>> file_path = Path("data.jsonl")
        >>> data = load_jsonl_to_dict(file_path)
        >>> print(data)
        [{'key1': 'value1'}, {'key2': 'value2'}]
    """
    with file_path.open("r", encoding="utf-8") as file:
        return [json.loads(line.strip()) for line in file if line.strip()]
save_model_to_json(file, model, indent=4, ensure_ascii=False)

Saves a Pydantic model to a JSON file, formatted with indentation for readability.

Parameters:

Name Type Description Default
file Path

Path to the JSON file where the model will be saved.

required
model BaseModel

The Pydantic model instance to save.

required
indent int

Number of spaces for JSON indentation. Defaults to 4.

4
ensure_ascii bool

Whether to escape non-ASCII characters. Defaults to False.

False

Raises:

Type Description
ValueError

If the model cannot be serialized to JSON.

IOError

If there is an issue writing to the file.

Example

class ExampleModel(BaseModel):
    name: str
    age: int

if __name__ == "__main__":
    model_instance = ExampleModel(name="John", age=30)
    json_file = Path("example.json")
    try:
        save_model_to_json(json_file, model_instance)
        print(f"Model saved to {json_file}")
    except (ValueError, IOError) as e:
        print(e)

Source code in src/tnh_scholar/utils/json_utils.py
def save_model_to_json(
    file: Path, model: BaseModel, indent: int = 4, ensure_ascii: bool = False
) -> None:
    """
    Saves a Pydantic model to a JSON file, formatted with indentation for readability.

    Args:
        file (Path): Path to the JSON file where the model will be saved.
        model (BaseModel): The Pydantic model instance to save.
        indent (int): Number of spaces for JSON indentation. Defaults to 4.
        ensure_ascii (bool): Whether to escape non-ASCII characters. Defaults to False.

    Raises:
        ValueError: If the model cannot be serialized to JSON.
        IOError: If there is an issue writing to the file.

    Example:
        class ExampleModel(BaseModel):
            name: str
            age: int

        if __name__ == "__main__":
            model_instance = ExampleModel(name="John", age=30)
            json_file = Path("example.json")
            try:
                save_model_to_json(json_file, model_instance)
                print(f"Model saved to {json_file}")
            except (ValueError, IOError) as e:
                print(e)
    """
    try:
        # Serialize model to JSON string
        model_dict = model.model_dump()
    except TypeError as e:
        raise ValueError(f"Error serializing model to JSON: {e}")

    # Write the JSON string to the file
    write_data_to_json_file(file, model_dict, indent=indent, ensure_ascii=ensure_ascii)
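
A save/load round-trip sketch (assuming the helpers are importable from tnh_scholar.utils.json_utils; the model and path are illustrative):

from pathlib import Path

from pydantic import BaseModel

from tnh_scholar.utils.json_utils import load_json_into_model, save_model_to_json


class TalkRecord(BaseModel):
    title: str
    speaker: str
    year: int


record = TalkRecord(title="Being Peace", speaker="Thich Nhat Hanh", year=1987)
json_file = Path("/tmp/talk_record.json")  # hypothetical location

save_model_to_json(json_file, record)
restored = load_json_into_model(json_file, TalkRecord)
print(restored)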
write_data_to_json_file(file, data, indent=4, ensure_ascii=False)

Writes a dictionary or list as a JSON string to a file, ensuring the parent directory exists, and supports formatting with indentation and ASCII control.

Parameters:

Name Type Description Default
file Path

Path to the JSON file where the data will be written.

required
data Union[dict, list]

The data to write to the file. Typically a dict or list.

required
indent int

Number of spaces for JSON indentation. Defaults to 4.

4
ensure_ascii bool

Whether to escape non-ASCII characters. Defaults to False.

False

Raises:

Type Description
ValueError

If the data cannot be serialized to JSON.

IOError

If there is an issue writing to the file.

Example

>>> from pathlib import Path
>>> data = {"key": "value"}
>>> write_data_to_json_file(Path("output.json"), data, indent=2, ensure_ascii=True)

Source code in src/tnh_scholar/utils/json_utils.py
def write_data_to_json_file(
    file: Path, data: Union[dict, list], indent: int = 4, ensure_ascii: bool = False
) -> None:
    """
    Writes a dictionary or list as a JSON string to a file, ensuring the parent directory exists,
    and supports formatting with indentation and ASCII control.

    Args:
        file (Path): Path to the JSON file where the data will be written.
        data (Union[dict, list]): The data to write to the file. Typically a dict or list.
        indent (int): Number of spaces for JSON indentation. Defaults to 4.
        ensure_ascii (bool): Whether to escape non-ASCII characters. Defaults to False.

    Raises:
        ValueError: If the data cannot be serialized to JSON.
        IOError: If there is an issue writing to the file.

    Example:
        >>> from pathlib import Path
        >>> data = {"key": "value"}
        >>> write_data_to_json_file(Path("output.json"), data, indent=2, ensure_ascii=True)
    """
    try:
        # Convert the data to a formatted JSON string
        json_str = json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
    except TypeError as e:
        raise ValueError(f"Error serializing data to JSON: {e}")

    try:
        # Ensure the parent directory exists
        file.parent.mkdir(parents=True, exist_ok=True)

        # Write the JSON string to the file
        with file.open("w", encoding="utf-8") as f:
            f.write(json_str)
    except IOError as e:
        raise IOError(f"Error writing JSON string to file '{file}': {e}")

lang

logger = get_child_logger(__name__) module-attribute
get_language_code(text)

Detect the language of the provided text using langdetect.

Parameters:

Name Type Description Default
text str

Text to analyze

required

Returns:

Name Type Description
str str

ISO 639-1 code for the detected language ('un' if detection fails).

Raises:

Type Description
ValueError

If text is empty or invalid

Source code in src/tnh_scholar/utils/lang.py
def get_language_code(text: str) -> str:
    """
    Detect the language of the provided text using langdetect.

    Args:
        text: Text to analyze

    Returns:
        str: ISO 639-1 code for the detected language ('un' if detection fails).

    Raises:
        ValueError: If text is empty or invalid
    """

    if not text or text.isspace():
        raise ValueError("Input text cannot be empty")

    sample = _get_sample_text(text)

    try:
        return detect(sample)
    except LangDetectException:
        logger.warning("Language could not be detected in get_language_code().")
        return "un"
get_language_from_code(code)
Source code in src/tnh_scholar/utils/lang.py
def get_language_from_code(code: str):
    if language := pycountry.languages.get(alpha_2=code):
        return language.name
    logger.warning(f"No language name found for code: {code}")
    return "Unknown"
get_language_name(text)
Source code in src/tnh_scholar/utils/lang.py
def get_language_name(text: str) -> str:
    return get_language_from_code(get_language_code(text))
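
A usage sketch for the three helpers above (assuming they are importable from tnh_scholar.utils.lang):

from tnh_scholar.utils.lang import (
    get_language_code,
    get_language_from_code,
    get_language_name,
)

sample = "Breathing in, I calm my body. Breathing out, I smile."

code = get_language_code(sample)     # e.g. "en"
print(code)
print(get_language_from_code(code))  # e.g. "English"
print(get_language_name(sample))     # combines the two calls above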

progress_utils

BAR_FORMAT = '{desc}: {percentage:3.0f}%|{bar}| Total: {total_fmt} sec. [elapsed: {elapsed}]' module-attribute
ExpectedTimeTQDM

A context manager for a time-based tqdm progress bar with optional delay.

  • 'expected_time': number of seconds we anticipate the task might take.
  • 'display_interval': how often (seconds) to refresh the bar.
  • 'desc': a short description for the bar.
  • 'delay_start': how many seconds to wait (sleep) before we even create/start the bar.

If the task finishes before 'delay_start' has elapsed, the bar may never appear.

Source code in src/tnh_scholar/utils/progress_utils.py
class ExpectedTimeTQDM:
    """
    A context manager for a time-based tqdm progress bar with optional delay.

    - 'expected_time': number of seconds we anticipate the task might take.
    - 'display_interval': how often (seconds) to refresh the bar.
    - 'desc': a short description for the bar.
    - 'delay_start': how many seconds to wait (sleep) before we even create/start the bar.

    If the task finishes before 'delay_start' has elapsed, the bar may never appear.
    """

    def __init__(
        self,
        expected_time: float,
        display_interval: float = 0.5,
        desc: str = "Time-based Progress",
        delay_start: float = 1.0,
    ) -> None:
        self.expected_time = round(expected_time)  # use nearest second.
        self.display_interval = display_interval
        self.desc = desc
        self.delay_start = delay_start

        self._stop_event = threading.Event()
        self._pbar = None  # We won't create the bar until after 'delay_start'
        self._start_time = None

    def __enter__(self):
        # Record the start time for reference
        self._start_time = time.time()

        # Spawn the background thread; it will handle waiting and then creating/updating the bar
        self._thread = threading.Thread(target=self._update_bar, daemon=True)
        self._thread.start()

        return self

    def _update_bar(self):
        # 1) Delay so warnings/logs can appear before the bar
        if self.delay_start > 0:
            time.sleep(self.delay_start)

        # 2) Create the tqdm bar (only now does it appear)
        self._pbar = tqdm(
            total=self.expected_time, desc=self.desc, unit="sec", bar_format=BAR_FORMAT
        )

        # 3) Update until told to stop
        while not self._stop_event.is_set():
            elapsed = time.time() - self._start_time
            current_value = min(elapsed, self.expected_time)
            if self._pbar:
                self._pbar.n = round(current_value)
                self._pbar.refresh()
            time.sleep(self.display_interval)

    def __exit__(self, exc_type, exc_value, traceback):
        # Signal the thread to stop
        self._stop_event.set()
        self._thread.join()

        # If the bar was actually created (i.e., we didn't finish too quickly),
        # do a final update and close
        if self._pbar:
            elapsed = time.time() - self._start_time
            self._pbar.n = round(min(elapsed, self.expected_time))
            self._pbar.refresh()
            self._pbar.close()

delay_start = delay_start instance-attribute
desc = desc instance-attribute
display_interval = display_interval instance-attribute
expected_time = round(expected_time) instance-attribute
__enter__()
Source code in src/tnh_scholar/utils/progress_utils.py
def __enter__(self):
    # Record the start time for reference
    self._start_time = time.time()

    # Spawn the background thread; it will handle waiting and then creating/updating the bar
    self._thread = threading.Thread(target=self._update_bar, daemon=True)
    self._thread.start()

    return self
__exit__(exc_type, exc_value, traceback)
Source code in src/tnh_scholar/utils/progress_utils.py
def __exit__(self, exc_type, exc_value, traceback):
    # Signal the thread to stop
    self._stop_event.set()
    self._thread.join()

    # If the bar was actually created (i.e., we didn't finish too quickly),
    # do a final update and close
    if self._pbar:
        elapsed = time.time() - self._start_time
        self._pbar.n = round(min(elapsed, self.expected_time))
        self._pbar.refresh()
        self._pbar.close()
__init__(expected_time, display_interval=0.5, desc='Time-based Progress', delay_start=1.0)
Source code in src/tnh_scholar/utils/progress_utils.py
def __init__(
    self,
    expected_time: float,
    display_interval: float = 0.5,
    desc: str = "Time-based Progress",
    delay_start: float = 1.0,
) -> None:
    self.expected_time = round(expected_time)  # use nearest second.
    self.display_interval = display_interval
    self.desc = desc
    self.delay_start = delay_start

    self._stop_event = threading.Event()
    self._pbar = None  # We won't create the bar until after 'delay_start'
    self._start_time = None
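
A usage sketch (assuming ExpectedTimeTQDM is importable from tnh_scholar.utils.progress_utils). The bar only appears if the work outlasts delay_start:

import time

from tnh_scholar.utils.progress_utils import ExpectedTimeTQDM

# Anticipate roughly 5 seconds of work; the bar appears after a 1-second delay.
with ExpectedTimeTQDM(expected_time=5, desc="Transcribing", delay_start=1.0):
    time.sleep(3)  # placeholder for real work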
TimeProgress

A context manager for a time-based progress display using dots.

The display updates once per second, printing a dot and showing:

  • Expected time (if provided)
  • Elapsed time (always displayed)

Example:

>>> import time
>>> with TimeProgress(expected_time=60, desc="Transcribing..."):
...     time.sleep(5)  # Simulate work
[Expected Time: 1:00, Elapsed Time: 0:05] .....

Parameters:

Name Type Description Default
expected_time Optional[float]

Expected time in seconds. Optional.

None
display_interval float

How often to print a dot (seconds).

1.0
desc str

Description to display alongside the progress.

''
Source code in src/tnh_scholar/utils/progress_utils.py
class TimeProgress:
    """
    A context manager for a time-based progress display using dots.

    The display updates once per second, printing a dot and showing:
    - Expected time (if provided)
    - Elapsed time (always displayed)

    Example:
    >>> import time
    >>> with TimeProgress(expected_time=60, desc="Transcribing..."):
    ...     time.sleep(5)  # Simulate work
    [Expected Time: 1:00, Elapsed Time: 0:05] .....

    Args:
        expected_time (Optional[float]): Expected time in seconds. Optional.
        display_interval (float): How often to print a dot (seconds).
        desc (str): Description to display alongside the progress.
    """

    def __init__(
        self,
        expected_time: Optional[float] = None,
        display_interval: float = 1.0,
        desc: str = "",
    ):
        self.expected_time = expected_time
        self.display_interval = display_interval
        self._stop_event = threading.Event()
        self._start_time = None
        self._thread = None
        self.desc = desc
        self._last_length = 0  # To keep track of the last printed line length

    def __enter__(self):
        # Record the start time
        self._start_time = time.time()

        # Spawn the background thread
        self._thread = threading.Thread(target=self._print_progress, daemon=True)
        self._thread.start()

        return self

    def _print_progress(self):
        """
        Continuously prints progress alternating between | and — along with elapsed/expected time.
        """
        symbols = ["|", "/", "—", "\\"]  # Symbols to alternate between
        symbol_index = 0  # Keep track of the current symbol

        while not self._stop_event.is_set():
            elapsed = time.time() - self._start_time

            # Format elapsed time as mm:ss
            elapsed_str = self._format_time(elapsed)

            # Format expected time if provided
            if self.expected_time is not None:
                expected_str = self._format_time(self.expected_time)
                header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
            else:
                header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

            # Get the current symbol for the spinner
            spinner = symbols[symbol_index]

            # Construct the line with the spinner
            line = f"\r{header} {spinner}"

            # Write to stdout
            sys.stdout.write(line)
            sys.stdout.flush()

            # Update the symbol index to alternate
            symbol_index = (symbol_index + 1) % len(symbols)

            # Sleep before next update
            time.sleep(self.display_interval)

        # Clear the spinner after finishing
        sys.stdout.write("\r" + " " * len(line) + "\r")
        sys.stdout.flush()

    def __exit__(self, exc_type, exc_value, traceback):
        # Signal the thread to stop
        self._stop_event.set()
        self._thread.join()

        # Final elapsed time
        elapsed = time.time() - self._start_time
        elapsed_str = self._format_time(elapsed)

        # Construct the final line
        if self.expected_time is not None:
            expected_str = self._format_time(self.expected_time)
            final_header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
        else:
            final_header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

        # Final dots
        final_line = f"\r{final_header}"

        # Clear the line and move to the next line
        padding = " " * max(self._last_length - len(final_line), 0)
        sys.stdout.write(final_line + padding + "\n")
        sys.stdout.flush()

    @staticmethod
    def _format_time(seconds: float) -> str:
        """
        Converts seconds to a formatted string (mm:ss).
        """
        minutes = int(seconds // 60)
        seconds = int(seconds % 60)
        return f"{minutes}:{seconds:02}"
desc = desc instance-attribute
display_interval = display_interval instance-attribute
expected_time = expected_time instance-attribute
__enter__()
Source code in src/tnh_scholar/utils/progress_utils.py
def __enter__(self):
    # Record the start time
    self._start_time = time.time()

    # Spawn the background thread
    self._thread = threading.Thread(target=self._print_progress, daemon=True)
    self._thread.start()

    return self
__exit__(exc_type, exc_value, traceback)
Source code in src/tnh_scholar/utils/progress_utils.py
def __exit__(self, exc_type, exc_value, traceback):
    # Signal the thread to stop
    self._stop_event.set()
    self._thread.join()

    # Final elapsed time
    elapsed = time.time() - self._start_time
    elapsed_str = self._format_time(elapsed)

    # Construct the final line
    if self.expected_time is not None:
        expected_str = self._format_time(self.expected_time)
        final_header = f"{self.desc} [Expected Time: {expected_str}, Elapsed Time: {elapsed_str}]"
    else:
        final_header = f"{self.desc} [Elapsed Time: {elapsed_str}]"

    # Final dots
    final_line = f"\r{final_header}"

    # Clear the line and move to the next line
    padding = " " * max(self._last_length - len(final_line), 0)
    sys.stdout.write(final_line + padding + "\n")
    sys.stdout.flush()
__init__(expected_time=None, display_interval=1.0, desc='')
Source code in src/tnh_scholar/utils/progress_utils.py
def __init__(
    self,
    expected_time: Optional[float] = None,
    display_interval: float = 1.0,
    desc: str = "",
):
    self.expected_time = expected_time
    self.display_interval = display_interval
    self._stop_event = threading.Event()
    self._start_time = None
    self._thread = None
    self.desc = desc
    self._last_length = 0  # To keep track of the last printed line length
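
A matching sketch for TimeProgress (same import-path assumption as above):

import time

from tnh_scholar.utils.progress_utils import TimeProgress

# Shows "[Expected Time: 0:30, Elapsed Time: m:ss]" with a small spinner.
with TimeProgress(expected_time=30, desc="Downloading audio"):
    time.sleep(2)  # placeholder for real work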

slugify

slugify(string)

Slugify a Unicode string.

Converts a string to a strict URL-friendly slug format, allowing only lowercase letters, digits, and hyphens.

Example

slugify("Héllø_Wörld!") 'hello-world'

Source code in src/tnh_scholar/utils/slugify.py
def slugify(string: str) -> str:
    """
    Slugify a Unicode string.

    Converts a string to a strict URL-friendly slug format,
    allowing only lowercase letters, digits, and hyphens.

    Example:
        >>> slugify("Héllø_Wörld!")
        'hello-world'
    """
    # Normalize Unicode to remove accents and convert to ASCII
    string = (
        unicodedata.normalize("NFKD", string).encode("ascii", "ignore").decode("ascii")
    )

    # Replace all non-alphanumeric characters with spaces (only keep a-z and 0-9)
    string = re.sub(r"[^a-z0-9\s]", " ", string.lower().strip())

    # Replace any sequence of spaces with a single hyphen
    return re.sub(r"\s+", "-", string)
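
A brief usage sketch, assuming slugify is importable from its source location shown above; the input strings are illustrative.

from tnh_scholar.utils.slugify import slugify  # import path assumed from the source location

print(slugify("Dharma Talk #42"))   # -> 'dharma-talk-42'
print(slugify("Thich Nhat Hanh"))   # -> 'thich-nhat-hanh'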

user_io_utils

get_single_char(prompt=None)

Get a single character from input, adapting to the execution environment.

Parameters:

Name Type Description Default
prompt Optional[str]

Optional prompt to display before getting input

None

Returns:

Type Description
str

A single character string from user input

Note
  • In terminal environments, uses raw input mode without requiring Enter
  • In Jupyter/IPython, falls back to regular input with message about Enter
Source code in src/tnh_scholar/utils/user_io_utils.py
def get_single_char(prompt: Optional[str] = None) -> str:
    """
    Get a single character from input, adapting to the execution environment.

    Args:
        prompt: Optional prompt to display before getting input

    Returns:
        A single character string from user input

    Note:
        - In terminal environments, uses raw input mode without requiring Enter
        - In Jupyter/IPython, falls back to regular input with message about Enter
    """
    # Check if we're in IPython/Jupyter
    is_notebook = hasattr(sys, 'ps1') or bool(sys.flags.interactive)

    if prompt:
        print(prompt, end='', flush=True)

    if is_notebook:
        # Jupyter/IPython environment - use regular input
        entry = input("Single character input required ")
        return entry[0] if entry else "\n" # Use newline if no entry

    # Terminal environment
    if os.name == "nt":  # Windows
        import msvcrt
        return msvcrt.getch().decode("utf-8")
    else:  # Unix-like
        import termios
        import tty

        try:
            fd = sys.stdin.fileno()
            old_settings = termios.tcgetattr(fd)
            try:
                tty.setraw(fd)
                char = sys.stdin.read(1)
            finally:
                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
            return char
        except termios.error:
            # Fallback if terminal handling fails
            return input("Single character input required ")[0]
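
A small sketch of a one-keystroke menu built on get_single_char; the prompt text is illustrative, and the import path is assumed from the source location shown above.

from tnh_scholar.utils.user_io_utils import get_single_char  # import path assumed

choice = get_single_char("Choose an action: [t]ranscribe / [s]kip ")
if choice.lower() == "t":
    print("transcribing...")
else:
    print("skipped")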
get_user_confirmation(prompt, default=True)

Prompt the user for a yes/no confirmation with single-character input. Cross-platform implementation. Returns True if 'y' is entered and False if 'n'; pressing Enter returns the default value.

Example usage:

    if get_user_confirmation("Do you want to continue"):
        print("Continuing...")
    else:
        print("Exiting...")

Source code in src/tnh_scholar/utils/user_io_utils.py
def get_user_confirmation(prompt: str, default: bool = True) -> bool:
    """
    Prompt the user for a yes/no confirmation with single-character input.
    Cross-platform implementation. Returns True if 'y' is entered, and False if 'n'
    Allows for default value if return is entered.

    Example usage
        if get_user_confirmation("Do you want to continue"):
            print("Continuing...")
        else:
            print("Exiting...")
    """
    print(f"{prompt} ", end="", flush=True)

    while True:
        char = get_single_char().lower()
        if char == "y":
            print(char)  # Echo the choice
            return True
        elif char == "n":
            print(char)
            return False
        elif char in ("\r", "\n"):  # Enter key (use default)
            print()  # Add a newline
            return default
        else:
            print(
                f"\nInvalid input: {char}. Please type 'y' or 'n': ", end="", flush=True
            )

validate

OCR_ENV_VARS = {'GOOGLE_APPLICATION_CREDENTIALS'} module-attribute
OPENAI_ENV_VARS = {'OPENAI_API_KEY'} module-attribute
logger = get_child_logger(__name__) module-attribute
check_env(required_vars, feature='this feature', output=True)

Check environment variables and provide user-friendly error messages.

Parameters:

Name Type Description Default
required_vars Set[str]

Set of environment variable names to check

required
feature str

Description of feature requiring these variables

'this feature'

Returns:

Name Type Description
bool bool

True if all required variables are set

Source code in src/tnh_scholar/utils/validate.py
def check_env(required_vars: Set[str], feature: str = "this feature", output: bool = True) -> bool:
    """
    Check environment variables and provide user-friendly error messages.

    Args:
        required_vars: Set of environment variable names to check
        feature: Description of feature requiring these variables

    Returns:
        bool: True if all required variables are set
    """
    if missing := [var for var in required_vars if not os.getenv(var)]:
        if output:
            message = get_env_message(missing, feature)
            logger.error(f"Missing environment variables: {', '.join(missing)}")
            print(message, file=sys.stderr)
        return False
    return True
check_ocr_env(output=True)

Check OCR processing requirements.

Source code in src/tnh_scholar/utils/validate.py
def check_ocr_env(output: bool = True) -> bool:
    """Check OCR processing requirements."""
    return check_env(OCR_ENV_VARS, "OCR processing", output=output)
check_openai_env(output=True)

Check OpenAI API requirements.

Source code in src/tnh_scholar/utils/validate.py
def check_openai_env(output: bool = True) -> bool:
    """Check OpenAI API requirements."""
    return check_env(OPENAI_ENV_VARS, "OpenAI API access", output=output)
get_env_message(missing_vars, feature='this feature')

Generate user-friendly environment setup message.

Parameters:

Name Type Description Default
missing_vars List[str]

List of missing environment variable names

required
feature str

Name of feature requiring the variables

'this feature'

Returns:

Type Description
str

Formatted error message with setup instructions

Source code in src/tnh_scholar/utils/validate.py
def get_env_message(missing_vars: List[str], feature: str = "this feature") -> str:
    """Generate user-friendly environment setup message.

    Args:
        missing_vars: List of missing environment variable names
        feature: Name of feature requiring the variables

    Returns:
        Formatted error message with setup instructions
    """
    export_cmds = " ".join(f"{var}=your_{var.lower()}_here" for var in missing_vars)

    return "\n".join([
        f"\nEnvironment Error: Missing required variables for {feature}:",
        ", ".join(missing_vars),
        "\nSet variables in your shell:",
        f"export {export_cmds}",
        "\nSee documentation for details.",
        "\nFor development: Add to .env file in project root.\n"
    ])
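
A short sketch of how these helpers might gate features at startup: fail fast if the OpenAI variables are missing (check_env prints the setup instructions built by get_env_message), and silently probe for OCR support. The import path is assumed from the source location shown above.

import sys

from tnh_scholar.utils.validate import check_ocr_env, check_openai_env  # import path assumed

if not check_openai_env():  # logs the error and prints setup instructions to stderr
    sys.exit(1)

ocr_available = check_ocr_env(output=False)  # quiet probe; just record availability
print(f"OCR support available: {ocr_available}")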

video_processing

video_processing

DEFAULT_TRANSCRIPT_DIR = Path.home() / '.yt_dlp_transcripts' module-attribute
DEFAULT_TRANSCRIPT_OPTIONS = {'skip_download': True, 'quiet': True, 'no_warnings': True, 'extract_flat': True, 'socket_timeout': 30, 'retries': 3, 'ignoreerrors': True, 'logger': logger} module-attribute
logger = get_child_logger(__name__) module-attribute
SubtitleTrack

Bases: TypedDict

Type definition for a subtitle track entry.

Source code in src/tnh_scholar/video_processing/video_processing.py
class SubtitleTrack(TypedDict):
    """Type definition for a subtitle track entry."""

    url: str
    ext: str
    name: str
ext instance-attribute
name instance-attribute
url instance-attribute
TranscriptNotFoundError

Bases: Exception

Raised when no transcript is available for the requested language.

Source code in src/tnh_scholar/video_processing/video_processing.py
class TranscriptNotFoundError(Exception):
    """Raised when no transcript is available for the requested language."""

    def __init__(
        self,
        video_url: str,
        language: str,
    ) -> None:
        """
        Initialize TranscriptNotFoundError.

        Args:
            video_url: URL of the video where transcript was not found
            language: Language code that was requested
            available_manual: List of available manual transcript languages
            available_auto: List of available auto-generated transcript languages
        """
        self.video_url = video_url
        self.language = language

        message = (
            f"No transcript found for {self.video_url} in language {self.language}. "
        )
        super().__init__(message)
language = language instance-attribute
video_url = video_url instance-attribute
__init__(video_url, language)

Initialize TranscriptNotFoundError.

Parameters:

Name Type Description Default
video_url str

URL of the video where transcript was not found

required
language str

Language code that was requested

required
Source code in src/tnh_scholar/video_processing/video_processing.py
def __init__(
    self,
    video_url: str,
    language: str,
) -> None:
    """
    Initialize TranscriptNotFoundError.

    Args:
        video_url: URL of the video where transcript was not found
        language: Language code that was requested
        available_manual: List of available manual transcript languages
        available_auto: List of available auto-generated transcript languages
    """
    self.video_url = video_url
    self.language = language

    message = (
        f"No transcript found for {self.video_url} in language {self.language}. "
    )
    super().__init__(message)
VideoInfo

Bases: TypedDict

Type definition for relevant video info fields.

Source code in src/tnh_scholar/video_processing/video_processing.py
class VideoInfo(TypedDict):
    """Type definition for relevant video info fields."""

    subtitles: Dict[str, List[SubtitleTrack]]
    automatic_captions: Dict[str, List[SubtitleTrack]]
automatic_captions instance-attribute
subtitles instance-attribute
download_audio_yt(url, output_dir, start_time=None, prompt_overwrite=True)

Downloads audio from a YouTube video using yt_dlp.YoutubeDL, with an optional start time.

Parameters:

Name Type Description Default
url str

URL of the YouTube video.

required
output_dir Path

Directory to save the downloaded audio file.

required
start_time str

Optional start time (e.g., '00:01:30' for 1 minute 30 seconds).

None

Returns:

Name Type Description
Path Path

Path to the downloaded audio file.

Source code in src/tnh_scholar/video_processing/video_processing.py
def download_audio_yt(
    url: str, output_dir: Path, start_time: str = None, prompt_overwrite=True
) -> Path:
    """
    Downloads audio from a YouTube video using yt_dlp.YoutubeDL, with an optional start time.

    Args:
        url (str): URL of the YouTube video.
        output_dir (Path): Directory to save the downloaded audio file.
        start_time (str): Optional start time (e.g., '00:01:30' for 1 minute 30 seconds).

    Returns:
        Path: Path to the downloaded audio file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    ydl_opts = {
        "format": "bestaudio/best",
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": "mp3",
                "preferredquality": "192",
            }
        ],
        "postprocessor_args": [],
        "noplaylist": True,
        "outtmpl": str(output_dir / "%(title)s.%(ext)s"),
    }

    # Add start time to the FFmpeg postprocessor if provided
    if start_time:
        ydl_opts["postprocessor_args"].extend(["-ss", start_time])
        logger.info(f"Postprocessor start time set to: {start_time}")

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)  # Extract metadata and download
        filename = ydl.prepare_filename(info)
        return Path(filename).with_suffix(".mp3")
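
A minimal usage sketch for download_audio_yt; the URL and directory are placeholders and the import path is assumed from the source file shown above.

from pathlib import Path

from tnh_scholar.video_processing.video_processing import download_audio_yt  # import path assumed

audio_path = download_audio_yt(
    url="https://www.youtube.com/watch?v=VIDEO_ID",  # placeholder URL
    output_dir=Path("./downloads"),
    start_time="00:01:30",  # optional: trim audio to start 90 seconds in
)
print(f"Audio saved to: {audio_path}")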
get_transcript(url, lang='en', download_dir=DEFAULT_TRANSCRIPT_DIR, keep_transcript_file=False)

Downloads and extracts the transcript for a given YouTube video URL.

Retrieves the transcript file, extracts the text content, and returns the raw text.

Parameters:

Name Type Description Default
url str

The URL of the YouTube video.

required
lang str

The language code for the transcript (default: 'en').

'en'
download_dir Path

The directory to download the transcript to.

DEFAULT_TRANSCRIPT_DIR
keep_transcript_file bool

Whether to keep the downloaded transcript file (default: False).

False

Returns:

Type Description
str

The extracted transcript text.

Raises:

Type Description
TranscriptNotFoundError

If no transcript is available in the specified language.

DownloadError

If video info extraction or download fails.

ValueError

If the downloaded transcript file is invalid or empty.

ParseError

If XML parsing of the transcript fails.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_transcript(
    url: str,
    lang: str = "en",
    download_dir: Path = DEFAULT_TRANSCRIPT_DIR,
    keep_transcript_file: bool = False,
) -> str:
    """Downloads and extracts the transcript for a given YouTube video URL.

    Retrieves the transcript file, extracts the text content, and returns the raw text.

    Args:
        url: The URL of the YouTube video.
        lang: The language code for the transcript (default: 'en').
        download_dir: The directory to download the transcript to.
        keep_transcript_file: Whether to keep the downloaded transcript file (default: False).

    Returns:
        The extracted transcript text.

    Raises:
        TranscriptNotFoundError: If no transcript is available in the specified language.
        yt_dlp.utils.DownloadError: If video info extraction or download fails.
        ValueError: If the downloaded transcript file is invalid or empty.
        ParseError: If XML parsing of the transcript fails.
    """

    transcript_file = _download_yt_ttml(download_dir, url=url, lang=lang)

    text = get_text_from_file(transcript_file)

    if not keep_transcript_file:
        try:
            os.remove(transcript_file)
            logger.debug(f"Removed temporary transcript file: {transcript_file}")
        except OSError as e:
            logger.warning(
                f"Failed to remove temporary transcript file {transcript_file}: {e}"
            )

    return _extract_ttml_text(text)
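
A short sketch of fetching a transcript and handling the no-transcript case; the URL is a placeholder and the import path is assumed from the source file shown above.

from tnh_scholar.video_processing.video_processing import (  # import path assumed
    TranscriptNotFoundError,
    get_transcript,
)

try:
    text = get_transcript("https://www.youtube.com/watch?v=VIDEO_ID", lang="en")
    print(text[:200])
except TranscriptNotFoundError as err:
    print(f"No transcript available: {err}")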
get_transcript_info(video_url, lang='en')

Retrieves the transcript URL for a video in the specified language.

Parameters:

Name Type Description Default
video_url str

The URL of the video

required
lang str

The desired language code

'en'

Returns:

Type Description

URL of the transcript

Raises:

Type Description
TranscriptNotFoundError

If no transcript is available in the specified language

DownloadError

If video info extraction fails

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_transcript_info(video_url: str, lang: str = "en"):
    """
    Retrieves the transcript URL for a video in the specified language.

    Args:
        video_url: The URL of the video
        lang: The desired language code

    Returns:
        URL of the transcript

    Raises:
        TranscriptNotFoundError: If no transcript is available in the specified language
        yt_dlp.utils.DownloadError: If video info extraction fails
    """
    options = {
        "writesubtitles": True,
        "writeautomaticsub": True,
        "subtitleslangs": [lang],
        "skip_download": True,
        #    'verbose': True
    }

    with yt_dlp.YoutubeDL(options) as ydl:
        # This may raise yt_dlp.utils.DownloadError which we let propagate
        info: VideoInfo = ydl.extract_info(video_url, download=False)  # type: ignore

        subtitles = info.get("subtitles", {})
        auto_subtitles = info.get("automatic_captions", {})

        # Log available subtitle information
        logger.debug("Available subtitles:")
        logger.debug(f"Manual subtitles: {list(subtitles.keys())}")
        logger.debug(f"Auto captions: {list(auto_subtitles.keys())}")

        if lang in subtitles:
            return subtitles[lang][0]["url"]
        elif lang in auto_subtitles:
            return auto_subtitles[lang][0]["url"]

        raise TranscriptNotFoundError(video_url=video_url, language=lang)
get_video_download_path_yt(output_dir, url)

Builds the expected download path for a YouTube video using yt-dlp metadata (the title-based filename in output_dir, with an .mp3 suffix), without downloading.

Parameters:

Name Type Description Default
output_dir Path

Directory used to build the expected output path.

required
url str

The YouTube URL.

required

Returns:

Name Type Description
Path Path

The expected path of the downloaded audio file.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_video_download_path_yt(output_dir: Path, url: str) -> Path:
    """
    Extracts the video title using yt-dlp.

    Args:
        url (str): The YouTube URL.

    Returns:
        str: The title of the video.
    """
    ydl_opts = {
        "quiet": True,  # Suppress output
        "skip_download": True,  # Don't download, just fetch metadata
        "outtmpl": str(output_dir / "%(title)s.%(ext)s"),
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(
            url, download=False
        )  # Extract metadata without downloading
        filepath = ydl.prepare_filename(info)

    return Path(filepath).with_suffix(".mp3")
get_youtube_urls_from_csv(file_path)

Reads a CSV file containing YouTube URLs and titles, logs the titles, and returns a list of URLs.

Parameters:

Name Type Description Default
file_path Path

Path to the CSV file containing YouTube URLs and titles.

required

Returns:

Type Description
List[str]

List[str]: List of YouTube URLs.

Raises:

Type Description
FileNotFoundError

If the file does not exist.

ValueError

If the CSV file is improperly formatted.

Source code in src/tnh_scholar/video_processing/video_processing.py
def get_youtube_urls_from_csv(file_path: Path) -> List[str]:
    """
    Reads a CSV file containing YouTube URLs and titles, logs the titles,
    and returns a list of URLs.

    Args:
        file_path (Path): Path to the CSV file containing YouTube URLs and titles.

    Returns:
        List[str]: List of YouTube URLs.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the CSV file is improperly formatted.
    """
    if not file_path.exists():
        logger.error(f"File not found: {file_path}")
        raise FileNotFoundError(f"File not found: {file_path}")

    urls = []

    with file_path.open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)

        if "url" not in reader.fieldnames or "title" not in reader.fieldnames:
            logger.error("CSV file must contain 'url' and 'title' columns.")
            raise ValueError("CSV file must contain 'url' and 'title' columns.")

        for row in reader:
            url = row["url"]
            title = row["title"]
            urls.append(url)
            logger.info(f"Found video title: {title}")

    return urls
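
The CSV is expected to contain 'url' and 'title' columns. A brief sketch (file name and contents are illustrative; import path assumed from the source file shown above):

from pathlib import Path

from tnh_scholar.video_processing.video_processing import get_youtube_urls_from_csv  # import path assumed

# videos.csv (illustrative contents):
#   url,title
#   https://www.youtube.com/watch?v=VIDEO_ID_1,Morning Dharma Talk
#   https://www.youtube.com/watch?v=VIDEO_ID_2,Walking Meditation
urls = get_youtube_urls_from_csv(Path("videos.csv"))
print(urls)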

yt_transcribe

DEFAULT_CHUNK_DURATION_MS = 10 * 60 * 1000 module-attribute
DEFAULT_CHUNK_DURATION_S = 10 * 60 module-attribute
DEFAULT_OUTPUT_DIR = './video_transcriptions' module-attribute
DEFAULT_PROMPT = 'Dharma, Deer Park, Thay, Thich Nhat Hanh, Bodhicitta, Bodhisattva, Mahayana' module-attribute
EXPECTED_ENV = 'tnh-scholar' module-attribute
args = parser.parse_args() module-attribute
group = parser.add_mutually_exclusive_group(required=True) module-attribute
logger = get_child_logger('yt_transcribe') module-attribute
output_directory = Path(args.output_dir) module-attribute
parser = argparse.ArgumentParser(description='Transcribe YouTube videos from a URL or a file containing URLs.') module-attribute
url_file = Path(args.file) module-attribute
video_urls = [] module-attribute
check_conda_env()
Source code in src/tnh_scholar/video_processing/yt_transcribe.py
def check_conda_env():
    active_env = os.environ.get("CONDA_DEFAULT_ENV")
    if active_env != EXPECTED_ENV:
        logger.warning(
            f"WARNING: The active conda environment is '{active_env}', but '{EXPECTED_ENV}' is required. "
            "Please activate the correct environment."
        )
        # Optionally exit the script
        sys.exit(1)
transcribe_youtube_videos(urls, output_base_dir, max_chunk_duration=DEFAULT_CHUNK_DURATION_S, start=None, translate=False)

Full pipeline for transcribing a list of YouTube videos.

Parameters:

Name Type Description Default
urls list[str]

List of YouTube video URLs.

required
output_base_dir Path

Base directory for storing output.

required
max_chunk_duration int

Maximum duration for audio chunks in seconds (default is 10 minutes).

DEFAULT_CHUNK_DURATION_S
Source code in src/tnh_scholar/video_processing/yt_transcribe.py
def transcribe_youtube_videos(
    urls: list[str],
    output_base_dir: Path,
    max_chunk_duration: int = DEFAULT_CHUNK_DURATION_S,
    start: str = None,
    translate=False,
):
    """
    Full pipeline for transcribing a list of YouTube videos.

    Args:
        urls (list[str]): List of YouTube video URLs.
        output_base_dir (Path): Base directory for storing output.
        max_chunk_duration (int): Maximum duration for audio chunks in seconds (default is 10 minutes).
    """
    output_base_dir.mkdir(parents=True, exist_ok=True)

    for url in urls:
        try:
            logger.info(f"Processing video: {url}")

            # Step 1: Download audio
            logger.info("Downloading audio...")
            tmp_audio_file = download_audio_yt(url, output_base_dir, start_time=start)
            logger.info(f"Downloaded audio file: {tmp_audio_file}")

            # Prepare directories for chunks and outputs
            video_name = (
                tmp_audio_file.stem
            )  # Use the stem of the audio file (title without extension)
            video_output_dir = output_base_dir / video_name
            chunks_dir = video_output_dir / "chunks"
            chunks_dir.mkdir(parents=True, exist_ok=True)

            # Create the video directory and move the audio file into it
            video_output_dir.mkdir(parents=True, exist_ok=True)
            audio_file = video_output_dir / tmp_audio_file.name

            try:
                tmp_audio_file.rename(
                    audio_file
                )  # Move the audio file to the video directory
                logger.info(f"Moved audio file to: {audio_file}")
            except Exception as e:
                logger.error(f"Failed to move audio file to {video_output_dir}: {e}")
                # Ensure the code gracefully handles issues here, reassigning to the original tmp path.
                audio_file = tmp_audio_file

            # Step 2: Detect boundaries
            logger.info("Detecting boundaries...")
            boundaries = detect_boundaries(audio_file)
            logger.info("Boundaries generated.")

            # Step 3: Split audio into chunks
            logger.info("Splitting audio into chunks...")
            split_audio_at_boundaries(
                audio_file=audio_file,
                boundaries=boundaries,
                output_dir=chunks_dir,
                max_duration=max_chunk_duration,
            )
            logger.info(f"Audio chunks saved to: {chunks_dir}")

            # Step 4: Transcribe audio chunks
            logger.info("Transcribing audio chunks...")
            transcript_file = video_output_dir / f"{video_name}.txt"
            jsonl_file = video_output_dir / f"{video_name}.jsonl"
            process_audio_chunks(
                directory=chunks_dir,
                output_file=transcript_file,
                jsonl_file=jsonl_file,
                prompt=DEFAULT_PROMPT,
                translate=translate,
            )
            logger.info(f"Transcription completed for {url}")
            logger.info(f"Transcript saved to: {transcript_file}")
            logger.info(f"Raw transcription data saved to: {jsonl_file}")

        except Exception as e:
            logger.error(f"Failed to process video {url}: {e}")

xml_processing

FormattingError

Bases: Exception

Custom exception raised for formatting-related errors.

Source code in src/tnh_scholar/xml_processing/xml_processing.py
class FormattingError(Exception):
    """
    Custom exception raised for formatting-related errors.
    """

    def __init__(self, message="An error occurred due to invalid formatting."):
        super().__init__(message)
__init__(message='An error occurred due to invalid formatting.')
Source code in src/tnh_scholar/xml_processing/xml_processing.py
def __init__(self, message="An error occurred due to invalid formatting."):
    super().__init__(message)

join_xml_data_to_doc(file_path, data, overwrite=False)

Joins a list of XML-tagged data with newlines, wraps it with <document> tags, and writes it to the specified file. Raises an exception if the file exists and overwrite is not set.

Parameters:

Name Type Description Default
file_path Path

Path to the output file.

required
data List[str]

List of XML-tagged data strings.

required
overwrite bool

Whether to overwrite the file if it exists.

False

Raises:

Type Description
FileExistsError

If the file exists and overwrite is False.

ValueError

If the data list is empty.

Example

>>> join_xml_data_to_doc(Path("output.xml"), ["<tag>Data</tag>"], overwrite=True)

Source code in src/tnh_scholar/xml_processing/xml_processing.py
def join_xml_data_to_doc(
    file_path: Path, data: List[str], overwrite: bool = False
) -> None:
    """
    Joins a list of XML-tagged data with newlines, wraps it with <document> tags,
    and writes it to the specified file. Raises an exception if the file exists
    and overwrite is not set.

    Args:
        file_path (Path): Path to the output file.
        data (List[str]): List of XML-tagged data strings.
        overwrite (bool): Whether to overwrite the file if it exists.

    Raises:
        FileExistsError: If the file exists and overwrite is False.
        ValueError: If the data list is empty.

    Example:
        >>> join_xml_data_to_doc(Path("output.xml"), ["<tag>Data</tag>"], overwrite=True)
    """
    if file_path.exists() and not overwrite:
        raise FileExistsError(
            f"The file {file_path} already exists and overwrite is not set."
        )

    if not data:
        raise ValueError("The data list cannot be empty.")

    # Create the XML content
    joined_data = "\n".join(data)  # Joining data with newline
    xml_content = f"<document>\n{joined_data}\n</document>"

    # Write to file
    file_path.write_text(xml_content, encoding="utf-8")

remove_page_tags(text)

Removes <page ...> and </page> tags from a text string.

Parameters: - text (str): The input text containing <page> tags.

Returns: - str: The text with <page> tags removed.

Source code in src/tnh_scholar/xml_processing/xml_processing.py
def remove_page_tags(text):
    """
    Removes <page ...> and </page> tags from a text string.

    Parameters:
    - text (str): The input text containing <page> tags.

    Returns:
    - str: The text with <page> tags removed.
    """
    # Remove opening <page ...> tags
    text = re.sub(r"<page[^>]*>", "", text)
    # Remove closing </page> tags
    text = re.sub(r"</page>", "", text)
    return text

save_pages_to_xml(output_xml_path, text_pages, overwrite=False)

Generates and saves an XML file containing text pages, with a <pagebreak> tag indicating where each page ends.

Parameters:

Name Type Description Default
output_xml_path Path

The Path object for the file where the XML file will be saved.

required
text_pages List[str]

A list of strings, each representing the text content of a page.

required
overwrite bool

If True, overwrites the file if it exists. Default is False.

False

Returns:

Type Description
None

None

Raises:

Type Description
ValueError

If the input list of text_pages is empty or contains invalid types.

FileExistsError

If the file already exists and overwrite is False.

PermissionError

If the file cannot be created due to insufficient permissions.

OSError

For other file I/O-related errors.

Source code in src/tnh_scholar/xml_processing/xml_processing.py
def save_pages_to_xml(
    output_xml_path: Path,
    text_pages: List[str],
    overwrite: bool = False,
) -> None:
    """
    Generates and saves an XML file containing text pages, with a <pagebreak> tag indicating the page ends.

    Parameters:
        output_xml_path (Path): The Path object for the file where the XML file will be saved.
        text_pages (List[str]): A list of strings, each representing the text content of a page.
        overwrite (bool): If True, overwrites the file if it exists. Default is False.

    Returns:
        None

    Raises:
        ValueError: If the input list of text_pages is empty or contains invalid types.
        FileExistsError: If the file already exists and overwrite is False.
        PermissionError: If the file cannot be created due to insufficient permissions.
        OSError: For other file I/O-related errors.
    """
    if not text_pages:
        raise ValueError("The text_pages list is empty. Cannot generate XML.")

    # Check if the file exists and handle overwrite behavior
    if output_xml_path.exists() and not overwrite:
        raise FileExistsError(
            f"The file '{output_xml_path}' already exists. Set overwrite=True to overwrite."
        )

    try:
        # Ensure the output directory exists
        output_xml_path.parent.mkdir(parents=True, exist_ok=True)

        # Write the XML file
        with output_xml_path.open("w", encoding="utf-8") as xml_file:
            # Write XML declaration and root element
            xml_file.write("<?xml version='1.0' encoding='UTF-8'?>\n")
            xml_file.write("<document>\n")

            # Add each page with its content and <pagebreak> tag
            for page_number, text in enumerate(text_pages, start=1):
                if not isinstance(text, str):
                    raise ValueError(
                        f"Invalid page content at index {page_number - 1}: expected a string."
                    )

                content = text.strip()
                escaped_text = escape(content)
                xml_file.write(f"    {escaped_text}\n")
                xml_file.write(f"    <pagebreak page='{page_number}' />\n")

            # Close the root element
            xml_file.write("</document>\n")

        print(f"XML file successfully saved at {output_xml_path}")

    except PermissionError as e:
        raise PermissionError(
            f"Permission denied while writing to {output_xml_path}: {e}"
        ) from e

    except OSError as e:
        raise OSError(
            f"An OS-related error occurred while saving XML file at {output_xml_path}: {e}"
        ) from e

    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}") from e

split_xml_on_pagebreaks(text, page_groups=None, keep_pagebreaks=True)

Splits an XML document into individual pages based on <pagebreak> tags. Optionally groups pages together based on page_groups and retains the <pagebreak> tags if keep_pagebreaks is True.

Parameters:

Name Type Description Default
text str

The XML document as a string.

required
page_groups Optional[List[Tuple[int, int]]]

A list of tuples defining page ranges to group together. Each tuple is of the form (start_page, end_page), inclusive.

None
keep_pagebreaks bool

Whether to retain the <pagebreak> tags in the returned data. Default is True.

True

Returns:

Type Description
List[str]

List[str]: A list of page contents as strings, either split by pages or grouped by page_groups.

Raises:

Type Description
ValueError

If the expected XML preamble or <document> tags are missing.

Source code in src/tnh_scholar/xml_processing/xml_processing.py
def split_xml_on_pagebreaks(
    text: str,
    page_groups: Optional[List[Tuple[int, int]]] = None,
    keep_pagebreaks: bool = True,
) -> List[str]:
    """
    Splits an XML document into individual pages based on <pagebreak> tags.
    Optionally groups pages together based on page_groups and retains <pagebreak> tags if keep_pagebreaks is True.

    Parameters:
        text (str): The XML document as a string.
        page_groups (Optional[List[Tuple[int, int]]]): A list of tuples defining page ranges to group together.
                                                      Each tuple is of the form (start_page, end_page), inclusive.
        keep_pagebreaks (bool): Whether to retain the <pagebreak> tags in the returned data. Default is False.

    Returns:
        List[str]: A list of page contents as strings, either split by pages or grouped by page_groups.

    Raises:
        ValueError: If the expected preamble or <document> tags are missing.
    """
    # Split text into lines
    lines = text.splitlines()

    # Preprocess: Remove `<?xml ... ?>` preamble and <document> tags
    if lines[0].startswith("<?xml"):
        lines.pop(0)
    else:
        raise ValueError("Missing `<?xml ... ?>` preamble on the first line.")
    if lines[0].strip() == "<document>":
        lines.pop(0)
    else:
        raise ValueError("Missing `<document>` opening tag on the second line.")
    if lines[-1].strip() == "</document>":
        lines.pop(-1)
    else:
        raise ValueError("Missing `</document>` closing tag on the last line.")

    # Process content to split pages based on <pagebreak> tags
    pages = []
    current_page = []

    for line in lines:
        if "<pagebreak" in line:  # Page boundary detected
            if current_page:
                page_content = "\n".join(current_page).strip()
                if keep_pagebreaks:
                    page_content += f"\n{line.strip()}"  # Retain the <pagebreak> tag
                pages.append(page_content)
                current_page = []
        else:
            current_page.append(line)

    # Append the last page if it exists
    if current_page:
        pages.append("\n".join(current_page).strip())

    # Validate that pages are extracted
    if not pages:
        raise ValueError("No pages found in the XML content.")

    # Group pages if page_groups is provided
    if page_groups:
        grouped_pages = []
        for start, end in page_groups:
            if group_content := [
                pages[i] for i in range(start - 1, end) if 0 <= i < len(pages)
            ]:
                grouped_pages.append("\n".join(group_content).strip())
        return grouped_pages

    return pages
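
A round-trip sketch tying save_pages_to_xml and split_xml_on_pagebreaks together: pages are written with <pagebreak> markers and then recovered individually or in groups. Paths and page text are illustrative; the import path is assumed from the source file shown above.

from pathlib import Path

from tnh_scholar.xml_processing.xml_processing import (  # import path assumed
    save_pages_to_xml,
    split_xml_on_pagebreaks,
)

pages = ["Page one text.", "Page two text.", "Page three text."]
xml_path = Path("scan.xml")
save_pages_to_xml(xml_path, pages, overwrite=True)

xml_text = xml_path.read_text(encoding="utf-8")

# Recover each page (the <pagebreak> markers are kept by default) ...
all_pages = split_xml_on_pagebreaks(xml_text)

# ... or group pages 1-2 together and keep page 3 on its own.
grouped = split_xml_on_pagebreaks(xml_text, page_groups=[(1, 2), (3, 3)])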

extract_tags

extract_unique_tags(xml_file)

Extract all unique tags from an XML file using lxml.

Parameters:

Name Type Description Default
xml_file str

Path to the XML file.

required

Returns:

Name Type Description
set

A set of unique tags in the XML document.

Source code in src/tnh_scholar/xml_processing/extract_tags.py
def extract_unique_tags(xml_file):
    """
    Extract all unique tags from an XML file using lxml.

    Parameters:
        xml_file (str): Path to the XML file.

    Returns:
        set: A set of unique tags in the XML document.
    """
    # Parse the XML file
    tree = etree.parse(xml_file)

    # Find all unique tags and return
    return {element.tag for element in tree.iter()}
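
A quick usage sketch; the function takes a file path and returns the set of element tags found (main() below wraps the same call for command-line use). The file name is illustrative and the import path is assumed from the source file shown above.

from tnh_scholar.xml_processing.extract_tags import extract_unique_tags  # import path assumed

tags = extract_unique_tags("talk.xml")
print(sorted(tags))  # e.g. ['document', 'pagebreak', 'section']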
main()
Source code in src/tnh_scholar/xml_processing/extract_tags.py
def main():
    # Create argument parser
    parser = argparse.ArgumentParser(
        description="Extract all unique tags from an XML file."
    )
    parser.add_argument("xml_file", type=str, help="Path to the XML file.")

    # Parse command-line arguments
    args = parser.parse_args()

    # Extract tags
    tags = extract_unique_tags(args.xml_file)

    # Print results
    print("Unique Tags Found:")
    for tag in sorted(tags):
        print(tag)
