medcat.utils.relation_extraction.ml_utils
Module Contents
Attributes
- medcat.utils.relation_extraction.ml_utils.logger
Functions
- medcat.utils.relation_extraction.ml_utils.split_list_train_test_by_class(data, sample_limit=-1, test_size=0.2, shuffle=True)
- Parameters:
data (List) – the relation instances, i.e. the “output_relations” value (see create_base_relations_from_doc/csv for the data columns).
sample_limit (int) – limits the number of samples per class, useful for dataset balancing. Defaults to -1.
test_size (float) – Defaults to 0.2.
shuffle (bool) – shuffle data randomly. Defaults to True.
- Returns:
Tuple[List, List] – train and test datasets
- Return type:
Tuple[List, List]
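A minimal sketch of a per-class train/test split follows. It assumes each relation instance carries a class label that a caller-supplied function can read (the hypothetical `label_of` parameter); the real column layout comes from create_base_relations_from_doc/csv and is not reproduced here. It also assumes that test_size is applied per class and that a non-positive sample_limit means "no cap". The `seed` parameter is a hypothetical addition for reproducibility.

```python
import random
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Tuple


def split_by_class_sketch(
    data: Sequence[Any],
    label_of: Callable[[Any], Any],  # hypothetical: extracts the class label from one instance
    sample_limit: int = -1,
    test_size: float = 0.2,
    shuffle: bool = True,
    seed: int = 0,  # hypothetical: makes the shuffle reproducible
) -> Tuple[List[Any], List[Any]]:
    """Split instances into train/test lists class by class (assumed semantics)."""
    rng = random.Random(seed)

    # Group instances by their class label.
    by_class: Dict[Any, List[Any]] = defaultdict(list)
    for instance in data:
        by_class[label_of(instance)].append(instance)

    train: List[Any] = []
    test: List[Any] = []
    for instances in by_class.values():
        if shuffle:
            rng.shuffle(instances)
        # Assumption: a non-positive sample_limit disables the per-class cap.
        if sample_limit > 0:
            instances = instances[:sample_limit]
        # The first test_size fraction of each class goes to the test set.
        n_test = int(len(instances) * test_size)
        test.extend(instances[:n_test])
        train.extend(instances[n_test:])
    return train, test
```

Splitting within each class keeps rare relation types represented in both sets, which a single global split does not guarantee.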
- medcat.utils.relation_extraction.ml_utils.load_bin_file(file_name, path='./')
- Return type:
Any
- medcat.utils.relation_extraction.ml_utils.save_bin_file(file_name, data, path='./')
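Both binary-file helpers take a file name plus a directory path. The serialization format is not documented here; the sketch below assumes pickle, and the `_sketch` function names are hypothetical, so read it as an illustration of the file_name/path round trip rather than the library's code.

```python
import os
import pickle
from typing import Any


def save_bin_file_sketch(file_name: str, data: Any, path: str = "./") -> None:
    """Serialize `data` to path/file_name (pickle is an assumption)."""
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, file_name), "wb") as f:
        pickle.dump(data, f)


def load_bin_file_sketch(file_name: str, path: str = "./") -> Any:
    """Read back an object written by save_bin_file_sketch."""
    with open(os.path.join(path, file_name), "rb") as f:
        return pickle.load(f)
```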
- medcat.utils.relation_extraction.ml_utils.save_state(model, optimizer, scheduler, epoch=1, best_f1=0.0, path='./', model_name='BERT', task='train', is_checkpoint=False, final_export=False)
Used by RelCAT.save() and RelCAT.train()
Saves the RelCAT model state. When checkpointing, multiple files are created (best_f1, loss and other scores). To export the model after training, set final_export=True and leave is_checkpoint=False.
- Parameters:
model (BaseModel_RelationExtraction) – e.g. BertModel_RelationExtraction or LlamaModel_RelationExtraction.
optimizer (torch.optim.AdamW, optional) – Defaults to None.
scheduler (torch.optim.lr_scheduler.MultiStepLR, optional) – Defaults to None.
epoch (int) – Defaults to 1.
best_f1 (float) – Defaults to 0.0.
path (str) – Defaults to “./”.
model_name (str) – Defaults to “BERT”. Used for checkpointing only.
task (str) – Defaults to “train”. Used for checkpointing only.
is_checkpoint (bool) – Defaults to False.
final_export (bool) – Defaults to False. If True, is_checkpoint must also be False; model.state_dict() is then exported to “model.dat”.
- Return type:
None
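The interaction between final_export and is_checkpoint can be sketched as follows. Only two rules come from the documentation above: final_export=True requires is_checkpoint=False, and a final export writes model.state_dict() to “model.dat”. Everything else is an assumption: the objects are duck-typed (anything with state_dict()), pickle stands in for torch.save so the sketch runs without PyTorch, it writes a single file where the real function may write several, and the checkpoint file name and dict keys are hypothetical.

```python
import os
import pickle
from typing import Any, Optional


def save_state_sketch(model: Any, optimizer: Optional[Any] = None, scheduler: Optional[Any] = None,
                      epoch: int = 1, best_f1: float = 0.0, path: str = "./",
                      model_name: str = "BERT", task: str = "train",
                      is_checkpoint: bool = False, final_export: bool = False) -> str:
    """Write model (and optionally optimizer/scheduler) state; return the file path."""
    # Documented constraint: a final export must not also be a checkpoint.
    if final_export and is_checkpoint:
        raise ValueError("final_export=True requires is_checkpoint=False")

    if final_export:
        # Documented behaviour: only model.state_dict() goes into "model.dat".
        file_path = os.path.join(path, "model.dat")
        payload: Any = model.state_dict()
    else:
        # Hypothetical naming scheme and keys; the real ones are not documented here.
        suffix = "_checkpoint" if is_checkpoint else ""
        file_path = os.path.join(path, f"{task}_{model_name}{suffix}.dat")
        payload = {
            "epoch": epoch,
            "best_f1": best_f1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
            "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
        }

    # pickle stands in for torch.save in this sketch.
    with open(file_path, "wb") as f:
        pickle.dump(payload, f)
    return file_path
```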
- medcat.utils.relation_extraction.ml_utils.load_state(model, optimizer, scheduler, path='./', model_name='BERT', file_prefix='train', load_best=False, relcat_config=ConfigRelCAT())
Used by RelCAT.load() and RelCAT.train()
- Parameters:
model (BaseModel_RelationExtraction) – must be initialized before calling this method, e.g. via BertModel_RelationExtraction(…) or LlamaModel_RelationExtraction(…).
optimizer (torch.optim.AdamW) – optimizer.
scheduler (torch.optim.lr_scheduler.MultiStepLR) – scheduler.
path (str, optional) – Defaults to “./”.
model_name (str, optional) – Defaults to “BERT”.
file_prefix (str, optional) – Defaults to “train”.
load_best (bool, optional) – Defaults to False.
relcat_config (ConfigRelCAT) – Defaults to ConfigRelCAT().
- Returns:
Tuple[int, int] – the last epoch and the F1 score.
- Return type:
Tuple[int, int]
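Loading restores state into objects that already exist, which is why the model must be constructed first. The sketch below shows that pattern on an in-memory checkpoint dict; the key names ("model_state_dict", "epoch", "best_f1", and so on) are assumptions, the objects are duck-typed (anything with load_state_dict()), and file reading, load_best and relcat_config handling are left out.

```python
from typing import Any, Dict, Optional, Tuple


def restore_state_sketch(model: Any, optimizer: Optional[Any], scheduler: Optional[Any],
                         checkpoint: Dict[str, Any]) -> Tuple[int, float]:
    """Restore pre-built objects in place from a checkpoint dict (assumed key layout)."""
    model.load_state_dict(checkpoint["model_state_dict"])
    # Optimizer/scheduler are restored only when both the object and its saved state exist.
    if optimizer is not None and checkpoint.get("optimizer_state_dict") is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if scheduler is not None and checkpoint.get("scheduler_state_dict") is not None:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # Return the last epoch and the F1 score, as load_state does.
    return checkpoint.get("epoch", 0), checkpoint.get("best_f1", 0.0)
```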
- medcat.utils.relation_extraction.ml_utils.save_results(data, model_name='BERT', path='./', file_prefix='train')
- Parameters:
model_name (str) – Defaults to “BERT”.
path (str) – Defaults to “./”.
file_prefix (str) – Defaults to “train”.
- medcat.utils.relation_extraction.ml_utils.load_results(path, model_name='BERT', file_prefix='train')
- Parameters:
model_name (str) – Defaults to “BERT”.
file_prefix (str) – Defaults to “train”.
- Return type:
Tuple[List, List, List]
- medcat.utils.relation_extraction.ml_utils.create_tokenizer_pretrain(tokenizer, relcat_config)
This method adds the default special tokens that we encounter to the tokenizer.
- Parameters:
tokenizer (BaseTokenizerWrapper_RelationExtraction) – BERT/Llama tokenizer.
relcat_config (ConfigRelCAT) – The RelCAT config.
- Returns:
BaseTokenizerWrapper_RelationExtraction – The same tokenizer.
- Return type:
medcat.utils.relation_extraction.tokenizer.BaseTokenizerWrapper_RelationExtraction
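A minimal sketch of what adding the special tokens amounts to: each token missing from the vocabulary gets a new id, tokens already present are left alone, and the same tokenizer object is returned. The ToyTokenizer class, its method name, and the example token strings are hypothetical stand-ins; the real function operates on a BaseTokenizerWrapper_RelationExtraction and takes its token list from ConfigRelCAT.

```python
from typing import Dict, List


class ToyTokenizer:
    """Hypothetical stand-in for a tokenizer wrapper: just a token -> id vocabulary."""

    def __init__(self, vocab: Dict[str, int]) -> None:
        self.vocab = dict(vocab)

    def add_tokens_if_missing(self, tokens: List[str]) -> int:
        """Add tokens that are not yet in the vocabulary; return how many were added."""
        added = 0
        for tok in tokens:
            if tok not in self.vocab:
                # Assumes ids are contiguous from 0, so the next free id is len(vocab).
                self.vocab[tok] = len(self.vocab)
                added += 1
        return added


def create_tokenizer_pretrain_sketch(tokenizer: ToyTokenizer, special_tokens: List[str]) -> ToyTokenizer:
    """Add the special tokens in place and return the same tokenizer object."""
    tokenizer.add_tokens_if_missing(special_tokens)
    return tokenizer
```

Because existing tokens are skipped, calling the function twice does not grow the vocabulary a second time.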
- medcat.utils.relation_extraction.ml_utils.create_dense_layers(relcat_config)
- Parameters:
relcat_config (medcat.config_rel_cat.ConfigRelCAT) –
- medcat.utils.relation_extraction.ml_utils.get_annotation_schema_tag(sequence_output, input_ids, special_tag)
Gets the token sequences from sequence_output for the specific token tag ids in self.relcat_config.general.annotation_schema_tag_ids.
- Parameters:
sequence_output (torch.Tensor) – hidden states/embeddings for each token in the input text
input_ids (torch.Tensor) – input token ids
special_tag (List) – special annotation token id pairs
- Returns:
torch.Tensor – new seq_tags
- Return type:
torch.Tensor
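One way to pull out hidden states for a pair of annotation tags is sketched below in PyTorch. It assumes special_tag is a [start_id, end_id] pair, that both tags occur in every sequence, that the first occurrence of each is the one wanted, and that the two vectors are concatenated per example. These are illustrative assumptions, not the documented behaviour of get_annotation_schema_tag, and the function name is hypothetical.

```python
from typing import List

import torch


def gather_tag_states_sketch(sequence_output: torch.Tensor, input_ids: torch.Tensor,
                             special_tag: List[int]) -> torch.Tensor:
    """Return hidden states at the first start/end tag positions, concatenated.

    sequence_output: (batch, seq_len, hidden); input_ids: (batch, seq_len).
    Output: (batch, 2 * hidden).
    """
    start_id, end_id = special_tag
    start_mask = input_ids == start_id
    end_mask = input_ids == end_id
    if not (start_mask.any(dim=1).all() and end_mask.any(dim=1).all()):
        raise ValueError("every sequence must contain both tag ids")
    # argmax returns the index of the first maximal value, i.e. the first True per row.
    start_pos = start_mask.float().argmax(dim=1)
    end_pos = end_mask.float().argmax(dim=1)
    rows = torch.arange(input_ids.size(0), device=input_ids.device)
    start_states = sequence_output[rows, start_pos]  # (batch, hidden)
    end_states = sequence_output[rows, end_pos]      # (batch, hidden)
    return torch.cat([start_states, end_states], dim=-1)
```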