medcat.utils.relation_extraction.pad_seq

Module Contents

Classes

Pad_Sequence

class medcat.utils.relation_extraction.pad_seq.Pad_Sequence(seq_pad_value, label_pad_value=-1)

__init__(seq_pad_value, label_pad_value=-1)
Used in rel_cat.py by RelCAT to create the DataLoaders for the train/test datasets.

A collate_fn for a DataLoader. It collates items whose input_ids, ent1/ent2 positions and labels have different lengths into fixed-length batches. It is applied to each batch separately, not to the whole DataLoader dataset, so the padded length can differ from one batch to the next. For each batch it produces the padded x sequences, the padded y sequences, and the x and y lengths.
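The collation described above can be sketched in pure Python, with lists standing in for tensors. This is a hypothetical analogue (`PadSequenceSketch` is not part of MedCAT) and it assumes each batch item is an `(input_ids, ent1_ent2_start_pos, labels)` triple; the real item layout is defined by `RelData.__getitem__()`, and the real class works on torch tensors.

```python
from typing import List, Sequence, Tuple

# Hypothetical item layout: (input_ids, ent1/ent2 start positions, labels).
Item = Tuple[List[int], List[int], List[int]]


class PadSequenceSketch:
    """Pure-Python stand-in for a padding collate_fn (lists instead of tensors)."""

    def __init__(self, seq_pad_value: int, label_pad_value: int = -1):
        self.seq_pad_value = seq_pad_value
        self.label_pad_value = label_pad_value

    @staticmethod
    def _pad(seqs: Sequence[List[int]], value: int) -> List[List[int]]:
        # Right-pad every sequence up to the longest one in this batch.
        longest = max((len(s) for s in seqs), default=0)
        return [list(s) + [value] * (longest - len(s)) for s in seqs]

    def __call__(self, batch: List[Item]):
        input_ids = [item[0] for item in batch]
        ent_pos = [item[1] for item in batch]  # kept as a plain list, not padded
        labels = [item[2] for item in batch]
        x_lengths = [len(s) for s in input_ids]  # lengths before padding
        y_lengths = [len(s) for s in labels]
        return (
            self._pad(input_ids, self.seq_pad_value),
            ent_pos,
            self._pad(labels, self.label_pad_value),
            x_lengths,
            y_lengths,
        )


collate = PadSequenceSketch(seq_pad_value=0)
batch = [([101, 7, 8, 102], [1, 2], [3]), ([101, 9, 102], [1, 1], [0, 2])]
padded_ids, ent_pos, padded_labels, x_len, y_len = collate(batch)
print(padded_ids)     # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(padded_labels)  # [[3, -1], [0, 2]]
print(x_len, y_len)   # [4, 3] [1, 2]
```

Making the object callable is what lets an instance be passed directly as a DataLoader's collate_fn while still carrying its configured pad values.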

Parameters:
  • seq_pad_value (int) – pad value for input_ids.

  • label_pad_value (int) – pad value for labels. Defaults to -1.
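Because padded label positions are filled with label_pad_value, choosing a value that can never be a real class index (such as the default -1) lets downstream code separate real labels from padding. A minimal sketch, assuming labels are non-negative class indices; `label_mask` is a hypothetical helper, not a MedCAT function:

```python
from typing import List


def label_mask(padded_labels: List[List[int]], label_pad_value: int = -1) -> List[List[bool]]:
    """Return True where a position holds a real label, False where it is padding."""
    return [[lab != label_pad_value for lab in row] for row in padded_labels]


padded_labels = [[3, -1, -1], [0, 2, -1]]
mask = label_mask(padded_labels)
print(mask)  # [[True, False, False], [True, True, False]]

# Keep only the real labels, e.g. before computing a metric.
real = [[lab for lab, keep in zip(row, m) if keep] for row, m in zip(padded_labels, mask)]
print(real)  # [[3], [0, 2]]
```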

__call__(batch)

Pads a batch of input_ids.

Parameters:

batch (List[torch.Tensor]) – the batch of Tensors from RelData.dataset (see its __getitem__() method for the data returned); the token sequences and labels are padded as needed. See https://pytorch.org/docs/stable/_modules/torch/nn/utils/rnn.html#pad_sequence for extra info.
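Since padding happens per batch, each batch is only padded up to its own longest sequence. The sketch below illustrates this with a simple batching loop standing in for a non-shuffled DataLoader; `pad_to_longest` and `iter_batches` are hypothetical helpers that use lists in place of tensors.

```python
from typing import Iterator, List


def pad_to_longest(seqs: List[List[int]], pad_value: int) -> List[List[int]]:
    # Right-pad each sequence to the longest one in this batch,
    # like pad_sequence(..., batch_first=True).
    longest = max(len(s) for s in seqs)
    return [s + [pad_value] * (longest - len(s)) for s in seqs]


def iter_batches(data: List[List[int]], batch_size: int) -> Iterator[List[List[int]]]:
    # Stand-in for a DataLoader without shuffling: yields consecutive slices.
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


dataset = [[1, 2], [3], [4, 5, 6, 7], [8, 9, 10]]
for batch in iter_batches(dataset, batch_size=2):
    print(pad_to_longest(batch, pad_value=0))
# [[1, 2], [3, 0]]
# [[4, 5, 6, 7], [8, 9, 10, 0]]
```

The first batch is padded to length 2 and the second to length 4, which avoids padding every sequence to the longest one in the whole dataset.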

Returns:

Tuple[Tensor, List, Tensor, LongTensor, LongTensor] – the padded data: padded input IDs, ent1 & ent2 start token positions, padded labels, input_id lengths and label lengths.

Return type:

Tuple[torch.Tensor, List, torch.Tensor, torch.LongTensor, torch.LongTensor]
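The lengths returned alongside the padded data make it possible to recover each original sequence. A pure-Python sketch with lists standing in for the returned tensors; the output values are made up for illustration and `unpad` is a hypothetical helper:

```python
from typing import List


def unpad(padded: List[List[int]], lengths: List[int]) -> List[List[int]]:
    """Trim each padded row back to its original length."""
    return [row[:n] for row, n in zip(padded, lengths)]


# Hypothetical outputs, in the documented order:
# input IDs, ent1/ent2 start positions, labels, input_id lengths, label lengths.
outputs = (
    [[101, 7, 8, 102], [101, 9, 102, 0]],
    [[1, 2], [1, 1]],
    [[3, -1], [0, 2]],
    [4, 3],
    [1, 2],
)
input_ids, ent_pos, labels, x_lengths, y_lengths = outputs
print(unpad(input_ids, x_lengths))  # [[101, 7, 8, 102], [101, 9, 102]]
print(unpad(labels, y_lengths))     # [[3], [0, 2]]
```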