:py:mod:`medcat.utils.relation_extraction.ml_utils`
===================================================

.. py:module:: medcat.utils.relation_extraction.ml_utils


Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   medcat.utils.relation_extraction.ml_utils.split_list_train_test_by_class
   medcat.utils.relation_extraction.ml_utils.load_bin_file
   medcat.utils.relation_extraction.ml_utils.save_bin_file
   medcat.utils.relation_extraction.ml_utils.save_state
   medcat.utils.relation_extraction.ml_utils.load_state
   medcat.utils.relation_extraction.ml_utils.save_results
   medcat.utils.relation_extraction.ml_utils.load_results
   medcat.utils.relation_extraction.ml_utils.create_tokenizer_pretrain
   medcat.utils.relation_extraction.ml_utils.create_dense_layers
   medcat.utils.relation_extraction.ml_utils.get_annotation_schema_tag



Attributes
~~~~~~~~~~

.. autoapisummary::

   medcat.utils.relation_extraction.ml_utils.logger


.. py:data:: logger


.. py:function:: split_list_train_test_by_class(data, sample_limit = -1, test_size = 0.2, shuffle = True)

   :param data: the "output_relations" relation instances (see create_base_relations_from_doc/csv for the data columns)
   :type data: List
   :param sample_limit: limits the number of samples per class, which is useful for dataset balancing. Defaults to -1.
   :type sample_limit: int
   :param test_size: Defaults to 0.2.
   :type test_size: float
   :param shuffle: shuffles the data randomly. Defaults to True.
   :type shuffle: bool

   :Returns: **Tuple[List, List]** -- train and test datasets


.. py:function:: load_bin_file(file_name, path='./')


.. py:function:: save_bin_file(file_name, data, path='./')


.. py:function:: save_state(model, optimizer, scheduler, epoch = 1, best_f1 = 0.0, path = './', model_name = 'BERT', task = 'train', is_checkpoint=False, final_export=False)

   Used by RelCAT.save() and RelCAT.train().

   Saves the RelCAT model state. For checkpointing, multiple files are created (best_f1, loss and other scores).
   To export the model after training, set final_export=True and leave is_checkpoint=False.

   :param model: BertModel_RelationExtraction | LlamaModel_RelationExtraction etc.
   :type model: BaseModel_RelationExtraction
   :param optimizer: May be None.
   :type optimizer: torch.optim.AdamW, optional
   :param scheduler: May be None.
   :type scheduler: torch.optim.lr_scheduler.MultiStepLR, optional
   :param epoch: Defaults to 1.
   :type epoch: int
   :param best_f1: Defaults to 0.0.
   :type best_f1: float
   :param path: Defaults to "./".
   :type path: str
   :param model_name: Defaults to "BERT". Used for checkpointing only.
   :type model_name: str
   :param task: Defaults to "train". Used for checkpointing only.
   :type task: str
   :param is_checkpoint: Defaults to False.
   :type is_checkpoint: bool
   :param final_export: Defaults to False. If True, is_checkpoint must also be False. Exports model.state_dict() to "model.dat".
   :type final_export: bool


.. py:function:: load_state(model, optimizer, scheduler, path='./', model_name='BERT', file_prefix='train', load_best=False, relcat_config = ConfigRelCAT())

   Used by RelCAT.load() and RelCAT.train().

   :param model: must be initialized before calling this method, via (Bert/Llama)Model_RelationExtraction(...)
   :type model: BaseModel_RelationExtraction
   :param optimizer: the optimizer
   :param scheduler: the scheduler
   :param path: Defaults to "./".
   :type path: str, optional
   :param model_name: Defaults to "BERT".
   :type model_name: str, optional
   :param file_prefix: Defaults to "train".
   :type file_prefix: str, optional
   :param load_best: Defaults to False.
   :type load_best: bool, optional
   :param relcat_config: Defaults to ConfigRelCAT().
   :type relcat_config: ConfigRelCAT

   :Returns: **Tuple** (*int, float*) -- last epoch and f1 score.


.. py:function:: save_results(data, model_name = 'BERT', path = './', file_prefix = 'train')


.. py:function:: load_results(path, model_name = 'BERT', file_prefix = 'train')


.. py:function:: create_tokenizer_pretrain(tokenizer, relcat_config)

   This method simply adds the default special tokens that we encounter.

   :param tokenizer: BERT/Llama tokenizer.
   :type tokenizer: BaseTokenizerWrapper_RelationExtraction
   :param relcat_config: The RelCAT config.
   :type relcat_config: ConfigRelCAT

   :Returns: **BaseTokenizerWrapper_RelationExtraction** -- The same tokenizer, with the special tokens added.


.. py:function:: create_dense_layers(relcat_config)


.. py:function:: get_annotation_schema_tag(sequence_output, input_ids, special_tag)

   Gets the token sequences from sequence_output for the specific token tag ids in self.relcat_config.general.annotation_schema_tag_ids.

   :param sequence_output: hidden states/embeddings for each token in the input text
   :type sequence_output: torch.Tensor
   :param input_ids: input token ids
   :type input_ids: torch.Tensor
   :param special_tag: special annotation token id pairs
   :type special_tag: List

   :Returns: **torch.Tensor** -- new seq_tags

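
As a rough illustration of the idea behind ``get_annotation_schema_tag``, the following sketch locates a start and an end tag id in each sequence and pools the hidden states of the tokens between them. It is a hypothetical example, not the MedCAT implementation: nested Python lists stand in for ``torch.Tensor`` so that it runs without PyTorch, the helper name ``pool_tagged_span`` is invented, and it assumes each sequence contains exactly one ``[start_id, end_id]`` pair and that max-pooling is the reduction applied.

.. code-block:: python

   from typing import List, Sequence


   def pool_tagged_span(
       sequence_output: Sequence[Sequence[Sequence[float]]],
       input_ids: Sequence[Sequence[int]],
       special_tag: Sequence[int],
   ) -> List[List[float]]:
       """Hypothetical helper: max-pool hidden states between a start and an end tag.

       sequence_output: [batch][seq_len][hidden] nested lists (stand-in for a tensor).
       input_ids: [batch][seq_len] token ids.
       special_tag: assumed [start_tag_id, end_tag_id] pair.
       Returns one pooled [hidden] vector per sequence.
       """
       start_id, end_id = special_tag
       pooled: List[List[float]] = []
       for hidden_states, ids in zip(sequence_output, input_ids):
           id_list = list(ids)
           start = id_list.index(start_id)           # first start tag
           end = id_list.index(end_id, start + 1)    # first end tag after it
           span = hidden_states[start + 1:end]       # tokens strictly between the tags
           if not span:
               raise ValueError("no tokens between the special tags")
           # Element-wise max over the token vectors in the span.
           pooled.append([max(column) for column in zip(*span)])
       return pooled

With real tensors, the tag positions could instead be found with ``torch.where(input_ids == tag_id)``, which returns the row and column indices of every match.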
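
The per-class splitting that ``split_list_train_test_by_class`` describes can be sketched as follows. This is a minimal sketch under stated assumptions rather than the library code: the function name ``split_by_class``, the ``label_key`` and ``seed`` parameters, and the dict-per-row data layout are hypothetical, and a non-positive ``sample_limit`` (such as the default -1) is assumed to mean "no limit".

.. code-block:: python

   import random
   from collections import defaultdict
   from typing import Any, Dict, List, Optional, Tuple

   Row = Dict[str, Any]


   def split_by_class(
       data: List[Row],
       label_key: str = "label_id",   # hypothetical column holding the class
       sample_limit: int = -1,
       test_size: float = 0.2,
       shuffle: bool = True,
       seed: Optional[int] = None,    # hypothetical, for reproducible shuffles
   ) -> Tuple[List[Row], List[Row]]:
       """Hypothetical per-class train/test split (not the MedCAT implementation)."""
       rng = random.Random(seed)
       by_class: Dict[Any, List[Row]] = defaultdict(list)
       for row in data:
           by_class[row[label_key]].append(row)

       train: List[Row] = []
       test: List[Row] = []
       for rows in by_class.values():
           if shuffle:
               rng.shuffle(rows)
           if sample_limit > 0:
               rows = rows[:sample_limit]  # assumed: cap each class for balancing
           n_test = int(len(rows) * test_size)  # test share taken per class
           test.extend(rows[:n_test])
           train.extend(rows[n_test:])
       return train, test

Splitting within each class keeps the label distribution of the test set close to that of the training set, and capping each class before the split is one way to balance an uneven dataset.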
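
The difference between checkpointing and a final export in ``save_state``, and the (epoch, f1) pair returned by ``load_state``, can be illustrated with a small stand-in. This is a hypothetical sketch, not the RelCAT code: ``pickle`` replaces ``torch.save``/``torch.load`` so that it runs without PyTorch, everything is written to a single file instead of the multiple checkpoint files the real function creates, and the names ``save_state_sketch``, ``load_state_sketch``, ``checkpoint.dat`` and the dictionary keys are invented. Returning ``(0, 0.0)`` when no checkpoint exists is also an assumption.

.. code-block:: python

   import os
   import pickle
   import tempfile
   from typing import Any, Dict, Optional, Tuple

   CHECKPOINT_FILE = "checkpoint.dat"  # hypothetical file name
   EXPORT_FILE = "model.dat"           # export name given for final_export


   def save_state_sketch(
       model_state: Dict[str, Any],
       optimizer_state: Optional[Dict[str, Any]],
       epoch: int = 1,
       best_f1: float = 0.0,
       path: str = "./",
       is_checkpoint: bool = False,
       final_export: bool = False,
   ) -> str:
       """Hypothetical stand-in for save_state; pickle replaces torch.save."""
       if final_export and is_checkpoint:
           raise ValueError("final_export=True requires is_checkpoint=False")
       if final_export:
           # Final export: only the model weights are written.
           payload: Any = model_state
           file_name = EXPORT_FILE
       else:
           # Checkpoint: weights plus the training progress needed to resume.
           payload = {"state_dict": model_state, "optimizer": optimizer_state,
                      "epoch": epoch, "best_f1": best_f1}
           file_name = CHECKPOINT_FILE
       full_path = os.path.join(path, file_name)
       with open(full_path, "wb") as f:
           pickle.dump(payload, f)
       return full_path


   def load_state_sketch(path: str = "./") -> Tuple[int, float]:
       """Return (last epoch, best f1) from a checkpoint, or (0, 0.0) if none exists."""
       full_path = os.path.join(path, CHECKPOINT_FILE)
       if not os.path.isfile(full_path):
           return 0, 0.0
       with open(full_path, "rb") as f:
           checkpoint = pickle.load(f)
       return checkpoint["epoch"], checkpoint["best_f1"]


   if __name__ == "__main__":
       with tempfile.TemporaryDirectory() as tmp:
           print(load_state_sketch(tmp))  # (0, 0.0): no checkpoint yet
           save_state_sketch({"w": [0.1]}, {"lr": 1e-5}, epoch=3, best_f1=0.8,
                             path=tmp, is_checkpoint=True)
           print(load_state_sketch(tmp))  # (3, 0.8)

In PyTorch code the payload would typically hold ``model.state_dict()``, ``optimizer.state_dict()`` and ``scheduler.state_dict()``, written with ``torch.save`` and restored through each object's ``load_state_dict`` method.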