Source code for aoptk.text_generation_api

from __future__ import annotations
import base64
import os
from itertools import product
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv
from jinja2 import Template
from openai import OpenAI
from aoptk.chemical import Chemical
from aoptk.effect import Effect
from aoptk.find_chemical import FindChemical
from aoptk.literature.convert_image import ConvertImage
from aoptk.literature.convert_pdf_scan import ConvertPDFScan
from aoptk.literature.find_relevant_publication import FindRelevantPublication
from aoptk.normalization.normalize_chemical import NormalizeChemical
from aoptk.relationship_type import Causative
from aoptk.relationship_type import Inhibitive
from aoptk.relationship_type import RelationshipType
from aoptk.relationships.find_relationship import FindRelationship
from aoptk.relationships.relationship import Relationship

# Relationship types known to the prompts. For a requested type, the remaining
# members are rendered into the templates as ``other_topics``.
topics: set[RelationshipType] = {Inhibitive(), Causative()}
class LLMFailureError(Exception):
    """Base class for capturing LLM failures (e.g. an empty model response).

    Args:
        message (str): Human-readable description of the failure. Optional so
            that ``raise LLMFailureError`` keeps working without arguments.
    """

    def __init__(self, message: str = "The language model returned no usable response.") -> None:
        # Forward to Exception so ``str(exc)`` and ``exc.args`` carry the message.
        super().__init__(message)
class TextGenerationAPI(
    FindChemical,
    FindRelationship,
    NormalizeChemical,
    ConvertPDFScan,
    ConvertImage,
    FindRelevantPublication,
):
    """Text generation API using OpenAI.

    Every task renders a Jinja2 prompt template from ``prompts_dir`` and sends it
    to the model as a single chat message (see ``_prompt``). The tasks are:

    - chemical extraction;
    - relationship classification in text, tables, and text with images;
    - chemical-name normalization;
    - PDF-scan and image conversion;
    - publication relevance screening.
    """

    # Chat role used for every message sent to the model.
    role: str = "user"
    # Sampling parameters forwarded to every completion request.
    temperature: float = 0
    top_p: float = 1
    # Runs while the class body executes (i.e. at import time), populating
    # os.environ from a .env file before __init__ reads CERIT_API_KEY.
    load_dotenv()
    # May be pre-set on the class (or a subclass) to share a client; when left as
    # None, each instance creates its own client in __init__.
    client: OpenAI | None = None
    # Directory holding the Jinja2 prompt templates named below.
    prompts_dir: Path = Path(__file__).resolve().parent / "prompts"
    chemical_prompt_template: str = "chemical_prompt.txt"
    relationship_text_prompt_template: str = "relationship_text_prompt.txt"
    relationship_text_images_prompt_template: str = "relationship_text_images_prompt.txt"
    relationships_table_prompt_template: str = "relationships_table_prompt.txt"
    normalization_prompt_template: str = "normalization_prompt.txt"
    convert_pdf_scan_prompt_template: str = "convert_pdf_scan_prompt.txt"
    convert_image_prompt_template: str = "convert_image_prompt.txt"
    find_relevant_publications_prompt_template: str = "find_relevant_publications_prompt.txt"
    # Extra text rendered into the relationship-in-text prompt.
    specification_relationship_text_prompt: str = ""

    def __init__(
        self,
        model: str = "gpt-oss-120b",
        url: str = "https://llm.ai.e-infra.cz/v1",
        api_key: str | None = None,
    ):
        """Configure the model endpoint and create the OpenAI client if needed.

        Args:
            model (str): Name of the model to query.
            url (str): Base URL of the OpenAI-compatible API.
            api_key (str | None): API key. When None, the ``CERIT_API_KEY``
                environment variable is read when the instance is created, so
                changes made after import are honoured.
        """
        self.model = model
        self.url = url
        self.api_key = api_key if api_key is not None else os.environ.get("CERIT_API_KEY")
        if self.client is None:
            self.client = OpenAI(
                base_url=self.url,
                api_key=self.api_key,
            )

    def find_relationships_in_text(
        self,
        text: str,
        chemicals: list[Chemical],
        effects: list[Effect],
        relationship_type: RelationshipType,
    ) -> list[Relationship]:
        """Find relationships between chemicals and effects.

        One prompt is sent per (chemical, effect) pair. Pairs whose answer matches
        neither label of ``relationship_type`` are skipped.

        Args:
            text (str): The input text.
            chemicals (list[Chemical]): List of chemical entities.
            effects (list[Effect]): List of effect entities.
            relationship_type (RelationshipType): The relationship type to classify.

        Returns:
            list[Relationship]: Relationships with ``text`` as their context.
        """
        relationships = []
        for chemical, effect in product(chemicals, effects):
            if (response := self._relationship_prompt(text, chemical, effect, relationship_type)) and (
                relationship := self._select_relationship_type(response, relationship_type)
            ):
                relationships.append(
                    Relationship(relationship_type=relationship, chemical=chemical, effect=effect, context=text),
                )
        return relationships

    def _relationship_prompt(
        self,
        text: str,
        chemical: Chemical,
        effect: Effect,
        relationship_type: RelationshipType,
    ) -> str:
        """Classify the relationship between a chemical and an effect.

        Args:
            text (str): The input text.
            chemical (Chemical): The chemical entity.
            effect (Effect): The effect entity.
            relationship_type (RelationshipType): The relationship type to classify.

        Returns:
            str: The stripped model response.
        """
        other_topics = topics.difference({relationship_type})
        content = self._render_prompt(
            self.relationship_text_prompt_template,
            text=text,
            chem=chemical.name,
            effect=effect.name,
            rel_type=relationship_type,
            other_topics=", ".join([topic.positive for topic in other_topics]),
            specification_relationship_text_prompt=self.specification_relationship_text_prompt,
        )
        return self._prompt(content)

    def _render_prompt(self, template_name: str, **context: object) -> str:
        """Render the Jinja2 template ``prompts_dir / template_name`` with ``context``.

        Args:
            template_name (str): File name of the template inside ``prompts_dir``.
            **context (object): Variables made available to the template.

        Returns:
            str: The rendered prompt text.
        """
        template_path = self.prompts_dir / template_name
        template_content = template_path.read_text(encoding="utf-8")
        return str(Template(template_content).render(**context))

    def _prompt(self, content: str | list[dict[str, object]]) -> str:
        """Send ``content`` as a single chat message and return the stripped reply.

        Args:
            content (str | list[dict[str, object]]): Plain prompt text, or a list of
                content parts (text and ``image_url`` entries) for multimodal input.

        Returns:
            str: The model's reply with surrounding whitespace removed.

        Raises:
            LLMFailureError: If the completion has no choices or an empty reply.
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            temperature=self.temperature,
            top_p=self.top_p,
            messages=[
                {
                    "role": self.role,
                    "content": content,
                },
            ],
        )
        # Guard against an empty choices list so callers only ever see LLMFailureError.
        if completion.choices and (response := completion.choices[0].message.content):
            return response.strip()
        raise LLMFailureError

    def _select_relationship_type(self, response: str, relationship_type: RelationshipType) -> str | None:
        """Select the relationship type based on the response.

        Args:
            response (str): The response from the model indicating the relationship type.
            relationship_type (RelationshipType): The relationship type to classify.

        Returns:
            str | None: ``relationship_type.positive`` or ``.negative`` when the
            response equals one of them exactly, otherwise None.
        """
        if response == relationship_type.positive:
            return relationship_type.positive
        if response == relationship_type.negative:
            return relationship_type.negative
        return None

    def find_chemicals(self, text: str) -> list[Chemical]:
        """Find chemicals in the given text.

        Args:
            text (str): The input text to search for chemicals.

        Returns:
            list[Chemical]: Lower-cased chemicals parsed from the ";"-separated
            model reply, or an empty list when the reply is "none".
        """
        response = self._prompt(self._render_prompt(self.chemical_prompt_template, text=text)).lower()
        if not response or response == "none":
            return []
        # Split on ";" rather than the exact " ; " so spacing variations in the reply
        # do not merge several chemicals into one name; empty fragments are dropped.
        return [Chemical(name=name) for part in response.split(";") if (name := part.strip())]

    def _encode_image(self, image_path: str) -> tuple[str, str]:
        """Encode the image at the given path to a base64 string and return MIME type.

        Args:
            image_path (str): The path to the image to encode.

        Returns:
            tuple[str, str]: A tuple of (base64_encoded_image, mime_type).
        """
        path = Path(image_path)
        subtype = path.suffix.lower()[1:]
        # "jpg" and "tif" are not registered MIME subtypes; use the canonical names.
        subtype = {"jpg": "jpeg", "tif": "tiff"}.get(subtype, subtype)
        base64_image = base64.b64encode(path.read_bytes()).decode("utf-8")
        return base64_image, f"image/{subtype}"

    def _process_colon_separated_response(
        self,
        response: str,
        effect: Effect,
        relationship_type: RelationshipType,
        image_path: str,
    ) -> list[Relationship]:
        """Parse ``chemical : classification`` lines from a model response.

        Lines without the " : " separator are ignored. Lines whose classification
        matches neither label of ``relationship_type`` are skipped as well,
        consistent with ``find_relationships_in_text``.

        Args:
            response (str): The response from the model.
            effect (Effect): The effect entity.
            relationship_type (RelationshipType): The relationship type to classify.
            image_path (str): Source label; ``Path(image_path).stem`` becomes the
                relationship context (callers pass e.g. "table").

        Returns:
            list[Relationship]: One relationship per recognized line.
        """
        relationships = []
        for raw_line in response.splitlines():
            chem_name, separator, classification = raw_line.strip().partition(" : ")
            if not separator:
                continue
            relationship = self._select_relationship_type(classification.strip().lower(), relationship_type)
            if relationship is None:
                # Do not record a Relationship with a None type for unrecognized labels.
                continue
            relationships.append(
                Relationship(
                    relationship_type=relationship,
                    chemical=Chemical(name=chem_name.strip().lower()),
                    effect=effect,
                    context=Path(image_path).stem,
                ),
            )
        return relationships

    def find_relationships_in_table(
        self,
        table_df: pd.DataFrame,
        effects: list[Effect],
        relationship_type: RelationshipType,
    ) -> list[Relationship]:
        """Find relationships between chemicals and effects in a table.

        Args:
            table_df (pd.DataFrame): Pandas DataFrame.
            effects (list[Effect]): List of effect entities.
            relationship_type (RelationshipType): The relationship type to classify.

        Returns:
            list[Relationship]: Relationships found for all effects.
        """
        relationships = []
        for effect in effects:
            relationships.extend(
                self._classify_relationships_in_table(
                    table_df,
                    effect,
                    relationship_type,
                ),
            )
        return relationships

    def _classify_relationships_in_table(
        self,
        table_df: pd.DataFrame,
        effect: Effect,
        relationship_type: RelationshipType,
    ) -> list[Relationship]:
        """Classify relationships between chemicals and an effect in a table.

        The table is serialized to CSV (without the index) and rendered into the
        table prompt.

        Args:
            table_df (pd.DataFrame): Pandas DataFrame.
            effect (Effect): The effect entity.
            relationship_type (RelationshipType): The relationship type to classify.

        Returns:
            list[Relationship]: List of relationships found in the table.
        """
        table_text = table_df.to_csv(index=False)
        content = self._render_prompt(
            self.relationships_table_prompt_template,
            effect=effect.name,
            rel_type=relationship_type,
            table=table_text,
        )
        if response := self._prompt(content):
            return self._process_colon_separated_response(response, effect, relationship_type, "table")
        return []

    def normalize_chemical(self, chemical: Chemical, chemical_list: list[Chemical]) -> Chemical:
        """Normalize the chemical name by finding a matching name in the chemical list.

        Sets ``chemical.heading`` in place and returns the same object.

        Args:
            chemical (Chemical): The chemical to normalize.
            chemical_list (list[Chemical]): The list of chemicals to match against.

        Returns:
            Chemical: The normalized chemical.
        """
        if matching_name := self._find_matching_name(chemical, chemical_list):
            chemical.heading = matching_name
        return chemical

    def _find_matching_name(self, chemical: Chemical, chemical_list: list[Chemical]) -> str:
        """Find a matching chemical name in the chemical list.

        Args:
            chemical (Chemical): The chemical to find a match for.
            chemical_list (list[Chemical]): The list of chemicals to match against.

        Returns:
            str: The lower-cased matching name from the model. When the model
            answers "none", the chemical's own name is returned instead.
        """
        content = self._render_prompt(
            self.normalization_prompt_template,
            chem=chemical.name,
            list_of_chemical_names="\n".join([chem.name for chem in chemical_list]),
        )
        if response := self._prompt(content).lower():
            if response == "none":
                return chemical.name
            return response
        return chemical.name

    def convert_pdf_scan(
        self,
        img_base64: str,
        mime_type: str = "image/jpeg",
    ) -> str:
        """Extract text from a base64-encoded image.

        Args:
            img_base64 (str): Base64-encoded image data.
            mime_type (str): MIME type of the image. Defaults to "image/jpeg".

        Returns:
            str: Extracted text from the image, or "" if the reply is empty.
        """
        content = [
            {
                "type": "text",
                "text": self._render_prompt(self.convert_pdf_scan_prompt_template),
            },
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{img_base64.strip()}"}},
        ]
        if response := self._prompt(content):
            return response
        return ""

    def find_relationships_in_text_and_images(
        self,
        text: str,
        image_paths: list[str],
        relationship_type: RelationshipType,
        effects: list[Effect],
    ) -> list[Relationship]:
        """Find relationships between chemicals and effects in the given text and images combined.

        Args:
            text (str): The input text.
            image_paths (list[str]): List of paths to images.
            relationship_type (RelationshipType): The relationship type to classify.
            effects (list[Effect]): List of effect entities.

        Returns:
            list[Relationship]: Relationships found for all effects.
        """
        relationships = []
        for effect in effects:
            relationships.extend(
                self._classify_relationships_in_text_and_images(text, image_paths, effect, relationship_type),
            )
        return relationships

    def _classify_relationships_in_text_and_images(
        self,
        text: str,
        image_paths: list[str],
        effect: Effect,
        relationship_type: RelationshipType,
    ) -> list[Relationship]:
        """Classify relationships between chemicals and an effect in the given text and images combined.

        All images are sent in one multimodal message together with the rendered
        text prompt.

        Args:
            text (str): The input text.
            image_paths (list[str]): List of paths to images.
            effect (Effect): The effect entity.
            relationship_type (RelationshipType): The relationship type to classify.

        Returns:
            list[Relationship]: Relationships with context "text and images".
        """
        other_topics = topics.difference({relationship_type})
        encoded_images = [self._encode_image(image_path) for image_path in image_paths]
        relationships = []
        content = [
            {
                "type": "text",
                "text": self._render_prompt(
                    self.relationship_text_images_prompt_template,
                    text=text,
                    effect=effect.name,
                    rel_type=relationship_type,
                    other_topics=", ".join([topic.positive for topic in other_topics]),
                ),
            },
        ]
        content.extend(
            {
                "type": "image_url",
                "image_url": {"url": f"data:{mime_type};base64,{img}"},
            }
            for img, mime_type in encoded_images
        )
        response = self._prompt(content)
        if response:
            relationships.extend(
                self._process_colon_separated_response(
                    response,
                    effect,
                    relationship_type,
                    "text and images",
                ),
            )
        return relationships

    def convert_image(
        self,
        image_path: str,
        text: str,
    ) -> str:
        """Convert an image to text.

        Args:
            image_path (str): Path to the image.
            text (str): The full text of the publication for context.

        Returns:
            str: The model's description of the image, or "" if the reply is empty.
        """
        base64_image, mime_type = self._encode_image(image_path)
        content = [
            {
                "type": "text",
                "text": self._render_prompt(self.convert_image_prompt_template, text=text),
            },
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}},
        ]
        if response := self._prompt(content):
            return response
        return ""

    def find_relevant_publications(self, question: str, text: str) -> bool | None:
        """Answer the question based on a given text.

        Args:
            question (str): The question to search for relevant publications.
            text (str): The extracted text of the publication.

        Returns:
            bool | None: True for "yes", False for "no" (case-insensitive, a
            trailing period tolerated), otherwise None.
        """
        answer = (
            self._prompt(
                self._render_prompt(self.find_relevant_publications_prompt_template, question=question, text=text),
            )
            .lower()
            .rstrip(".")
        )
        if answer == "yes":
            return True
        if answer == "no":
            return False
        return None