medcat.utils.relation_extraction.pad_seq
Module Contents
Classes
- class medcat.utils.relation_extraction.pad_seq.Pad_Sequence(seq_pad_value, label_pad_value=-1)
- Parameters:
seq_pad_value (int) – pad value for input_ids.
label_pad_value (int) – pad value for labels. Defaults to -1.
- __init__(seq_pad_value, label_pad_value=-1)
- Used in rel_cat.py by RelCAT to create DataLoaders for the train/test datasets.
A collate_fn for a DataLoader that collates input_ids, ent1/ent2 and label sequences of different lengths into a fixed-length batch. Padding is applied per batch, not across the whole DataLoader data. The result holds the padded x sequence, the y sequence, and the x and y lengths of the batch.
- Parameters:
seq_pad_value (int) – pad value for input_ids.
label_pad_value (int) – pad value for labels. Defaults to -1.
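To illustrate what per-batch padding means, here is a minimal dependency-free sketch. The helper `pad_batch`, the token ids and the pad value 0 are assumptions made up for illustration. None of them is part of MedCAT, and the sketch pads plain Python lists where Pad_Sequence works on tensors.

```python
from typing import List


def pad_batch(seqs: List[List[int]], pad_value: int) -> List[List[int]]:
    """Pad every sequence to the length of the longest one in this batch."""
    width = max(len(s) for s in seqs)
    return [s + [pad_value] * (width - len(s)) for s in seqs]


# Two batches drawn from the same (hypothetical) dataset.
batch_a = [[101, 7, 8, 102], [101, 9, 102]]
batch_b = [[101, 5, 102], [101, 102]]

padded_a = pad_batch(batch_a, pad_value=0)
padded_b = pad_batch(batch_b, pad_value=0)

# Each batch is only as wide as its own longest sequence.
print(len(padded_a[0]), len(padded_b[0]))  # prints: 4 3
```

Padding per batch keeps short batches narrow, so they do not grow to the length of the longest sequence in the whole dataset.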
- __call__(batch)
Pads a batch of input_ids and labels.
- Parameters:
batch (List[torch.Tensor]) – the batch of Tensors from RelData.dataset (check the __getitem__() method for the data returned). The token sequences and labels are padded as needed. See https://pytorch.org/docs/stable/_modules/torch/nn/utils/rnn.html#pad_sequence for more information.
- Returns:
Tuple[Tensor, List, Tensor, LongTensor, LongTensor] – padded input ids, ent1 and ent2 start token positions, padded labels, input id lengths, label lengths
- Return type:
Tuple[torch.Tensor, List, torch.Tensor, torch.LongTensor, torch.LongTensor]
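
A minimal sketch of a collate callable that returns a five-element tuple of the kind described above, built on torch.nn.utils.rnn.pad_sequence. The class name `PadCollateSketch`, the three-element item layout (input_ids, entity start positions, labels) and the toy values are assumptions for illustration. This is not MedCAT's implementation; check the __getitem__() method of RelData for the real item layout.

```python
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class PadCollateSketch:
    """Hypothetical stand-in for a padding collate_fn (not MedCAT's code)."""

    def __init__(self, seq_pad_value: int, label_pad_value: int = -1):
        self.seq_pad_value = seq_pad_value
        self.label_pad_value = label_pad_value

    def __call__(
        self, batch: List[tuple]
    ) -> Tuple[torch.Tensor, list, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Assumed item layout: (input_ids, ent1/ent2 start positions, labels).
        seqs = [item[0] for item in batch]
        ent_pos = [item[1] for item in batch]
        labels = [item[2] for item in batch]
        # Pad only up to the longest sequence in this batch.
        seqs_padded = pad_sequence(
            seqs, batch_first=True, padding_value=self.seq_pad_value
        )
        labels_padded = pad_sequence(
            labels, batch_first=True, padding_value=self.label_pad_value
        )
        x_lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
        y_lengths = torch.tensor([len(y) for y in labels], dtype=torch.long)
        return seqs_padded, ent_pos, labels_padded, x_lengths, y_lengths


# Toy items with made-up token ids, positions and labels.
items = [
    (torch.tensor([101, 7, 8, 102]), torch.tensor([1, 2]), torch.tensor([0])),
    (torch.tensor([101, 9, 102]), torch.tensor([1, 1]), torch.tensor([1])),
]

loader = DataLoader(
    items, batch_size=2, collate_fn=PadCollateSketch(seq_pad_value=0)
)
input_ids, ent_pos, labels, x_lengths, y_lengths = next(iter(loader))
print(input_ids.shape, x_lengths.tolist())
```

In this sketch the entity positions are passed through as a plain list and only the token ids and labels are padded, which matches the List entry in the return type above.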