Skip to content

Learning from Crowds

Routines for deep learning for crowds.

CoNAL

Bases: Module

Common Noise Adaptation Layers (CoNAL). This method introduces two types of confusions: worker-specific and global. Each is parameterized by a confusion matrix. The ratio of the two confusions is determined by the common noise adaptation layer. The common noise adaptation layer is a trainable function that takes the instance embedding and the worker ID as input and outputs a scalar value between 0 and 1.

Zhendong Chu, Jing Ma, and Hongning Wang. Learning from Crowds by Modeling Common Confusions. Proceedings of the AAAI Conference on Artificial Intelligence, 35(7), 5832-5840, 2021. https://doi.org/10.1609/aaai.v35i7.16730

Examples:

>>> from crowdkit.learning import CoNAL
>>> import torch
>>> input = torch.randn(3, 5)
>>> workers = torch.tensor([0, 1, 0])
>>> embeddings = torch.randn(3, 5)
>>> conal = CoNAL(5, 2)
>>> conal(embeddings, input, workers)
Source code in crowdkit/learning/conal.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class CoNAL(nn.Module):
    """
    Common Noise Adaptation Layers (CoNAL). This method introduces two types of confusions: worker-specific and
    global. Each is parameterized by a confusion matrix. The ratio of the two confusions is determined by the
    common noise adaptation layer. The common noise adaptation layer is a trainable function that takes the
    instance embedding and the worker ID as input and outputs a scalar value between 0 and 1.

    Zhendong Chu, Jing Ma, and Hongning Wang. Learning from Crowds by Modeling Common Confusions.
    *Proceedings of the AAAI Conference on Artificial Intelligence*, 35(7), 5832-5840, 2021.
    https://doi.org/10.1609/aaai.v35i7.16730

    Examples:
        >>> from crowdkit.learning import CoNAL
        >>> import torch
        >>> input = torch.randn(3, 5)
        >>> workers = torch.tensor([0, 1, 0])
        >>> embeddings = torch.randn(3, 5)
        >>> conal = CoNAL(5, 2)
        >>> conal(embeddings, input, workers)
    """

    def __init__(
        self,
        num_labels: int,
        n_workers: int,
        com_emb_size: int = 20,
        user_feature: Optional[NDArray[np.float32]] = None,
    ):
        """
        Initializes the CoNAL module.

        Args:
            num_labels (int): Number of classes.
            n_workers (int): Number of annotators.
            com_emb_size (int): Embedding size of the common noise module.
            user_feature (np.ndarray): Optional 2-D array of worker features whose rows are
                indexed by worker ID. When omitted, a one-hot encoding of the worker ID
                (an ``n_workers x n_workers`` identity matrix) is used.
        """
        super().__init__()
        self.n_workers = n_workers
        # One (num_labels x num_labels) confusion matrix per annotator.
        # NOTE(review): `_identity_init` is defined elsewhere; presumably it returns
        # identity-like matrices of the requested shape -- confirm.
        self.annotator_confusion_matrices = nn.Parameter(
            _identity_init((n_workers, num_labels, num_labels)),
            requires_grad=True,
        )

        # A single confusion matrix shared by all annotators.
        self.common_confusion_matrix = nn.Parameter(
            _identity_init((num_labels, num_labels)), requires_grad=True
        )

        # `user_feature or ...` must not be used here: the truth value of a NumPy
        # array with more than one element is ambiguous and raises ValueError.
        if user_feature is None:
            user_feature = np.eye(n_workers, dtype=np.float32)
        # Worker features are fixed inputs, not trained (requires_grad=False).
        self.user_feature_vec = nn.Parameter(
            torch.from_numpy(user_feature).float(), requires_grad=False
        )
        # LazyLinear infers the embedding size from the first batch it sees.
        self.diff_linear_1 = nn.LazyLinear(128)
        self.diff_linear_2 = nn.Linear(128, com_emb_size)
        self.user_feature_1 = nn.Linear(self.user_feature_vec.size(1), com_emb_size)

    def simple_common_module(
        self, input: torch.Tensor, workers: torch.Tensor
    ) -> torch.Tensor:
        """
        Common noise adaptation module.

        Projects the instance embedding and the worker features into a shared
        space of size ``com_emb_size``, normalizes both, and squashes their dot
        product with a sigmoid.

        Args:
            input (torch.Tensor): Tensor of shape (batch_size, embedding_size)
            workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, 1) containing the common noise rate.
        """
        instance_difficulty = self.diff_linear_1(input)
        instance_difficulty = self.diff_linear_2(instance_difficulty)

        instance_difficulty = F.normalize(instance_difficulty)
        user_feature = self.user_feature_1(self.user_feature_vec[workers])
        user_feature = F.normalize(user_feature)
        # Row-wise dot product of the two normalized projections.
        common_rate = torch.sum(instance_difficulty * user_feature, dim=1)
        common_rate = torch.sigmoid(common_rate).unsqueeze(1)
        return common_rate

    def forward(
        self, embeddings: torch.Tensor, logits: torch.Tensor, workers: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the CoNAL module.

        Args:
            embeddings (torch.Tensor): Tensor whose first dimension is the batch; all
                remaining dimensions are flattened to (batch_size, embedding_size).
            logits (torch.Tensor): Tensor of shape (batch_size, num_classes)
            workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, num_classes) mixing the common and
                the worker-specific noisy predictions, weighted by the common noise rate.
        """
        x = embeddings.view(embeddings.size(0), -1)
        common_rate = self.simple_common_module(x, workers)
        # Class probabilities pushed through the shared confusion matrix.
        common_prob = torch.einsum(
            "ij,jk->ik", (F.softmax(logits, dim=-1), self.common_confusion_matrix)
        )
        batch_confusion_matrices = self.annotator_confusion_matrices[workers]
        # NOTE(review): `differentiable_ds` is defined elsewhere; presumably it applies
        # each worker's confusion matrix to the logits -- confirm.
        indivi_prob = differentiable_ds(logits, batch_confusion_matrices)
        # common_rate has shape (batch_size, 1) and broadcasts over the class axis.
        crowd_out: torch.Tensor = (
            common_rate * common_prob + (1 - common_rate) * indivi_prob
        )  # single instance
        return crowd_out

__init__(num_labels, n_workers, com_emb_size=20, user_feature=None)

Initializes the CoNAL module.

Parameters:

Name Type Description Default
num_labels int

Number of classes.

required
n_workers int

Number of annotators.

required
com_emb_size int

Embedding size of the common noise module.

20
user_feature ndarray

User feature vector.

None
Source code in crowdkit/learning/conal.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def __init__(
    self,
    num_labels: int,
    n_workers: int,
    com_emb_size: int = 20,
    user_feature: Optional[NDArray[np.float32]] = None,
):
    """
    Initializes the CoNAL module.

    Args:
        num_labels (int): Number of classes.
        n_workers (int): Number of annotators.
        com_emb_size (int): Embedding size of the common noise module.
        user_feature (np.ndarray): Optional 2-D array of worker features whose rows are
            indexed by worker ID. When omitted, a one-hot encoding of the worker ID
            (an ``n_workers x n_workers`` identity matrix) is used.
    """
    super().__init__()
    self.n_workers = n_workers
    # One (num_labels x num_labels) confusion matrix per annotator.
    # NOTE(review): `_identity_init` is defined elsewhere; presumably it returns
    # identity-like matrices of the requested shape -- confirm.
    self.annotator_confusion_matrices = nn.Parameter(
        _identity_init((n_workers, num_labels, num_labels)),
        requires_grad=True,
    )

    # A single confusion matrix shared by all annotators.
    self.common_confusion_matrix = nn.Parameter(
        _identity_init((num_labels, num_labels)), requires_grad=True
    )

    # `user_feature or ...` must not be used here: the truth value of a NumPy
    # array with more than one element is ambiguous and raises ValueError.
    if user_feature is None:
        user_feature = np.eye(n_workers, dtype=np.float32)
    # Worker features are fixed inputs, not trained (requires_grad=False).
    self.user_feature_vec = nn.Parameter(
        torch.from_numpy(user_feature).float(), requires_grad=False
    )
    # LazyLinear infers the embedding size from the first batch it sees.
    self.diff_linear_1 = nn.LazyLinear(128)
    self.diff_linear_2 = nn.Linear(128, com_emb_size)
    self.user_feature_1 = nn.Linear(self.user_feature_vec.size(1), com_emb_size)

forward(embeddings, logits, workers)

Forward pass of the CoNAL module.

Parameters:

Name Type Description Default
embeddings Tensor

Tensor of shape (batch_size, embedding_size)

required
logits Tensor

Tensor of shape (batch_size, num_classes)

required
workers Tensor

Tensor of shape (batch_size,) containing the worker IDs.

required

Returns:

Type Description
Tensor

torch.Tensor: Tensor of shape (batch_size, num_classes) containing the predicted output probabilities.

Source code in crowdkit/learning/conal.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def forward(
    self, embeddings: torch.Tensor, logits: torch.Tensor, workers: torch.Tensor
) -> torch.Tensor:
    """
    Combine the common and the worker-specific noisy predictions.

    Args:
        embeddings (torch.Tensor): Instance embeddings; everything after the batch
            dimension is flattened to (batch_size, embedding_size).
        logits (torch.Tensor): Tensor of shape (batch_size, num_classes)
        workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

    Returns:
        torch.Tensor: Tensor of shape (batch_size, num_classes) holding the mixture of
            both predictions, weighted by the common noise rate.
    """
    flat_embeddings = embeddings.view(embeddings.size(0), -1)
    # Per-instance weight in (0, 1) given to the shared confusion matrix.
    mix_weight = self.simple_common_module(flat_embeddings, workers)

    class_probs = F.softmax(logits, dim=-1)
    shared_noise_probs = torch.einsum(
        "ij,jk->ik", class_probs, self.common_confusion_matrix
    )

    # NOTE(review): `differentiable_ds` is defined elsewhere; presumably it applies
    # each worker's confusion matrix to the logits -- confirm.
    per_worker_matrices = self.annotator_confusion_matrices[workers]
    worker_noise_probs = differentiable_ds(logits, per_worker_matrices)

    # mix_weight is (batch_size, 1), so it broadcasts across the class axis.
    crowd_out: torch.Tensor = (
        mix_weight * shared_noise_probs + (1 - mix_weight) * worker_noise_probs
    )
    return crowd_out

simple_common_module(input, workers)

Common noise adaptation module.

Parameters:

Name Type Description Default
input Tensor

Tensor of shape (batch_size, embedding_size)

required
workers Tensor

Tensor of shape (batch_size,) containing the worker IDs.

required

Returns:

Type Description
Tensor

torch.Tensor: Tensor of shape (batch_size, 1) containing the common noise rate.

Source code in crowdkit/learning/conal.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def simple_common_module(
    self, input: torch.Tensor, workers: torch.Tensor
) -> torch.Tensor:
    """
    Common noise adaptation module.

    Embeds the instance and the annotator into a shared space, normalizes both
    embeddings and maps their dot product through a sigmoid.

    Args:
        input (torch.Tensor): Tensor of shape (batch_size, embedding_size)
        workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

    Returns:
        torch.Tensor: Tensor of shape (batch_size, 1) containing the common noise rate.
    """
    # Two stacked linear projections of the instance embedding.
    difficulty = F.normalize(self.diff_linear_2(self.diff_linear_1(input)))

    # Look up each sample's worker features, then project them to the same size.
    worker_embedding = F.normalize(
        self.user_feature_1(self.user_feature_vec[workers])
    )

    # Row-wise dot product, squashed into (0, 1) and kept as a column vector.
    similarity = (difficulty * worker_embedding).sum(dim=1)
    return torch.sigmoid(similarity).unsqueeze(1)

CrowdLayer

Bases: Module

CrowdLayer module for classification tasks.

This method applies a worker-specific transformation of the logits. There are four types of transformations:

- MW: Multiplication on the worker's confusion matrix.
- VW: Element-wise multiplication with the worker's weight vector.
- VB: Element-wise addition with the worker's bias vector.
- VW + b: Combination of VW and VB: VW * logits + b.

Filipe Rodrigues and Francisco Pereira. Deep Learning from Crowds. Proceedings of the AAAI Conference on Artificial Intelligence, 32(1), 2018. https://doi.org/10.1609/aaai.v32i1.11506

Examples:

>>> from crowdkit.learning import CrowdLayer
>>> import torch
>>> input = torch.randn(3, 5)
>>> workers = torch.tensor([0, 1, 0])
>>> cl = CrowdLayer(5, 2, conn_type="mw")
>>> cl(input, workers)
Source code in crowdkit/learning/crowd_layer.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class CrowdLayer(nn.Module):
    """
    CrowdLayer module for classification tasks.

    Applies a per-worker transformation to the logits. The transformation is
    selected with `conn_type`:

    - MW: Multiplication on the worker's confusion matrix.
    - VW: Element-wise multiplication with the worker's weight vector.
    - VB: Element-wise addition with the worker's bias vector.
    - VW + b: Combination of VW and VB: VW * logits + b.

    Filipe Rodrigues and Francisco Pereira. Deep Learning from Crowds.
    *Proceedings of the AAAI Conference on Artificial Intelligence, 32(1),* 2018.
    https://doi.org/10.1609/aaai.v32i1.11506

    Examples:
        >>> from crowdkit.learning import CrowdLayer
        >>> import torch
        >>> input = torch.randn(3, 5)
        >>> workers = torch.tensor([0, 1, 0])
        >>> cl = CrowdLayer(5, 2, conn_type="mw")
        >>> cl(input, workers)
    """

    def __init__(
        self,
        num_labels: int,
        n_workers: int,
        conn_type: str = "mw",
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Create the per-worker parameters for the chosen connection type.

        Args:
            num_labels (int): Number of classes.
            n_workers (int): Number of workers.
            conn_type (str): Connection type. One of 'mw', 'vw', 'vb', 'vw+b'.
            device (torch.device): Device to use.
            dtype (torch.dtype): Data type to use.

        Raises:
            ValueError: If conn_type is not one of 'mw', 'vw', 'vb', 'vw+b'.
        """
        super().__init__()
        self.conn_type = conn_type
        self.n_workers = n_workers

        vector_shape = (n_workers, num_labels)
        factory_kwargs = {"dtype": dtype, "device": device}

        if conn_type == "mw":
            # NOTE(review): `batch_identity_matrices` is defined elsewhere; presumably
            # it yields one identity matrix per worker -- confirm.
            self.weight = nn.Parameter(
                batch_identity_matrices(n_workers, num_labels, **factory_kwargs)
            )
            return
        if conn_type == "vw":
            # Multiplicative weights start at one.
            self.weight = nn.Parameter(torch.ones(vector_shape, **factory_kwargs))
            return
        if conn_type == "vb":
            # Additive biases start at zero.
            self.weight = nn.Parameter(torch.zeros(vector_shape, **factory_kwargs))
            return
        if conn_type == "vw+b":
            self.scale = nn.Parameter(torch.ones(vector_shape, **factory_kwargs))
            self.bias = nn.Parameter(torch.zeros(vector_shape, **factory_kwargs))
            return
        raise ValueError("Unknown connection type for CrowdLayer.")

    def forward(self, outputs: torch.Tensor, workers: torch.Tensor) -> torch.Tensor:
        """
        Apply the configured worker-specific transformation.

        Args:
            outputs (torch.Tensor): Tensor of shape (batch_size, input_dim)
            workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, num_labels)

        Raises:
            ValueError: If `self.conn_type` is not one of the supported types.
        """
        kind = self.conn_type
        if kind == "mw":
            return crowd_layer_mw(outputs, workers, self.weight)
        if kind == "vw":
            return crowd_layer_vw(outputs, workers, self.weight)
        if kind == "vb":
            return crowd_layer_vb(outputs, workers, self.weight)
        if kind == "vw+b":
            return crowd_layer_vw_b(outputs, workers, self.scale, self.bias)
        raise ValueError("Unknown connection type for CrowdLayer.")

__init__(num_labels, n_workers, conn_type='mw', device=None, dtype=None)

Parameters:

Name Type Description Default
num_labels int

Number of classes.

required
n_workers int

Number of workers.

required
conn_type str

Connection type. One of 'mw', 'vw', 'vb', 'vw+b'.

'mw'
device DeviceObjType

Device to use.

None
dtype dtype

Data type to use.

None

Raises: ValueError: If conn_type is not one of 'mw', 'vw', 'vb', 'vw+b'.

Source code in crowdkit/learning/crowd_layer.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def __init__(
    self,
    num_labels: int,
    n_workers: int,
    conn_type: str = "mw",
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Create the per-worker parameters for the chosen connection type.

    Args:
        num_labels (int): Number of classes.
        n_workers (int): Number of workers.
        conn_type (str): Connection type. One of 'mw', 'vw', 'vb', 'vw+b'.
        device (torch.device): Device to use.
        dtype (torch.dtype): Data type to use.

    Raises:
        ValueError: If conn_type is not one of 'mw', 'vw', 'vb', 'vw+b'.
    """
    super(CrowdLayer, self).__init__()
    self.conn_type = conn_type
    self.n_workers = n_workers

    vector_shape = (n_workers, num_labels)
    factory_kwargs = {"dtype": dtype, "device": device}

    if conn_type == "mw":
        # NOTE(review): `batch_identity_matrices` is defined elsewhere; presumably
        # it yields one identity matrix per worker -- confirm.
        self.weight = nn.Parameter(
            batch_identity_matrices(n_workers, num_labels, **factory_kwargs)
        )
        return
    if conn_type == "vw":
        # Multiplicative weights start at one.
        self.weight = nn.Parameter(torch.ones(vector_shape, **factory_kwargs))
        return
    if conn_type == "vb":
        # Additive biases start at zero.
        self.weight = nn.Parameter(torch.zeros(vector_shape, **factory_kwargs))
        return
    if conn_type == "vw+b":
        self.scale = nn.Parameter(torch.ones(vector_shape, **factory_kwargs))
        self.bias = nn.Parameter(torch.zeros(vector_shape, **factory_kwargs))
        return
    raise ValueError("Unknown connection type for CrowdLayer.")

forward(outputs, workers)

Forward pass.

Parameters:

Name Type Description Default
outputs Tensor

Tensor of shape (batch_size, input_dim)

required
workers Tensor

Tensor of shape (batch_size,) containing the worker IDs.

required

Returns:

Type Description
Tensor

torch.Tensor: Tensor of shape (batch_size, num_labels)

Source code in crowdkit/learning/crowd_layer.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def forward(self, outputs: torch.Tensor, workers: torch.Tensor) -> torch.Tensor:
    """
    Apply the configured worker-specific transformation.

    Args:
        outputs (torch.Tensor): Tensor of shape (batch_size, input_dim)
        workers (torch.Tensor): Tensor of shape (batch_size,) containing the worker IDs.

    Returns:
        torch.Tensor: Tensor of shape (batch_size, num_labels)

    Raises:
        ValueError: If `self.conn_type` is not one of the supported types.
    """
    kind = self.conn_type
    if kind == "mw":
        return crowd_layer_mw(outputs, workers, self.weight)
    if kind == "vw":
        return crowd_layer_vw(outputs, workers, self.weight)
    if kind == "vb":
        return crowd_layer_vb(outputs, workers, self.weight)
    if kind == "vw+b":
        # The combined variant keeps separate scale and bias parameters.
        return crowd_layer_vw_b(outputs, workers, self.scale, self.bias)
    raise ValueError("Unknown connection type for CrowdLayer.")

TextSummarization

Bases: BaseTextsAggregator

Text Aggregation through Summarization

The method uses a pre-trained language model for summarization to aggregate crowdsourced texts. For each task, texts are concatenated by the | token and passed as the model's input. If n_permutations is not None, texts are randomly shuffled n_permutations times and then the outputs are aggregated with permutation_aggregator if provided. If permutation_aggregator is not provided, the resulting aggregate is the most common output over permuted inputs.

To use pretrained model and tokenizer from transformers, you need to install torch

M. Orzhenovskii, "Fine-Tuning Pre-Trained Language Model for Crowdsourced Texts Aggregation," Proceedings of the 2nd Crowd Science Workshop: Trust, Ethics, and Excellence in Crowdsourced Data Management at Scale, 2021, pp. 8-14. https://ceur-ws.org/Vol-2932/short1.pdf

S. Pletenev, "Noisy Text Sequences Aggregation as a Summarization Subtask," Proceedings of the 2nd Crowd Science Workshop: Trust, Ethics, and Excellence in Crowdsourced Data Management at Scale, 2021, pp. 15-20. https://ceur-ws.org/Vol-2932/short2.pdf

Examples:

>>> import torch
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
>>> from crowdkit.learning import TextSummarization
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> mname = "toloka/t5-large-for-text-aggregation"
>>> tokenizer = AutoTokenizer.from_pretrained(mname)
>>> model = AutoModelForSeq2SeqLM.from_pretrained(mname)
>>> agg = TextSummarization(tokenizer, model, device=device)
>>> result = agg.fit_predict(df)
...
Source code in crowdkit/learning/text_summarization.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
@attr.s
class TextSummarization(BaseTextsAggregator):
    """Text Aggregation through Summarization

    The method uses a pre-trained language model for summarization to aggregate crowdsourced texts.
    For each task, texts are concatenated by ` | ` token and passed as a model's input. If
    `n_permutations` is not `None`, texts are randomly shuffled `n_permutations` times and then
    outputs are aggregated with `permutation_aggregator` if provided. If `permutation_aggregator`
    is not provided, the resulting aggregate is the most common output over permuted inputs.

    **To use pretrained model and tokenizer from `transformers`, you need to install [torch](https://pytorch.org/get-started/locally/#start-locally)**

    M. Orzhenovskii,
    "Fine-Tuning Pre-Trained Language Model for Crowdsourced Texts Aggregation,"
    Proceedings of the 2nd Crowd Science Workshop: Trust, Ethics, and Excellence in Crowdsourced Data Management at Scale, 2021, pp. 8-14.
    <https://ceur-ws.org/Vol-2932/short1.pdf>

    S. Pletenev,
    "Noisy Text Sequences Aggregation as a Summarization Subtask,"
    Proceedings of the 2nd Crowd Science Workshop: Trust, Ethics, and Excellence in Crowdsourced Data Management at Scale, 2021, pp. 15-20.
    <https://ceur-ws.org/Vol-2932/short2.pdf>

    Examples:
        >>> import torch
        >>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
        >>> from crowdkit.learning import TextSummarization
        >>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
        >>> mname = "toloka/t5-large-for-text-aggregation"
        >>> tokenizer = AutoTokenizer.from_pretrained(mname)
        >>> model = AutoModelForSeq2SeqLM.from_pretrained(mname)
        >>> agg = TextSummarization(tokenizer, model, device=device)
        >>> result = agg.fit_predict(df)
        ...
    """

    tokenizer: PreTrainedTokenizer = attr.ib()
    """[Pre-trained tokenizer](https://huggingface.co/transformers/main_classes/tokenizer.html#pretrainedtokenizer)."""

    model: PreTrainedModel = attr.ib()
    """[Pre-trained model](https://huggingface.co/transformers/main_classes/model.html#pretrainedmodel) for text summarization."""

    concat_token: str = attr.ib(default=" | ")
    """Token used for the workers' texts concatenation."""

    num_beams: int = attr.ib(default=16)
    """Number of beams for beam search. 1 means no beam search."""

    n_permutations: Optional[int] = attr.ib(default=None)
    """Number of input permutations to use. If `None`, use a single permutation according to the input's order."""

    permutation_aggregator: Optional[BaseTextsAggregator] = attr.ib(default=None)
    """Text aggregation method to use for aggregating outputs of multiple input permutations if `n_permutations` is set."""

    device: str = attr.ib(default="cpu")
    """Device to use such as `cpu` or `cuda`."""

    def fit_predict(self, data: pd.DataFrame) -> "pd.Series[Any]":
        """Run the aggregation and return the aggregated texts.

        The model is moved to `self.device` before generation, and the result
        is also stored in `self.texts_`.

        Args:
            data (DataFrame): Workers' text outputs.
                A pandas.DataFrame containing `task`, `worker` and `text` columns.

        Returns:
            Series: Tasks' texts.
                A pandas.Series indexed by `task` such that `result.loc[task]`
                is the task's text.
        """

        data = data[["task", "worker", "text"]]

        self.model = self.model.to(self.device)
        self.texts_ = data.groupby("task")["text"].apply(self._aggregate_one)
        return self.texts_

    def _aggregate_one(self, outputs: "pd.Series[Any]") -> str:
        """Aggregate the texts submitted for a single task.

        Without `n_permutations`, the texts are summarized once in their given
        order. Otherwise up to `n_permutations` distinct orderings are sampled,
        each is summarized, and the summaries are combined either with
        `permutation_aggregator` or by taking the most frequent summary.

        Args:
            outputs (Series): The workers' texts for one task.

        Returns:
            str: The aggregated text for the task.
        """
        if not self.n_permutations:
            return self._generate_output(outputs)

        # TODO: generate only `n_permutations` permutations
        permutations = list(itertools.permutations(outputs))
        # Sampling without replacement cannot draw more items than exist, so
        # cap the request at the number of available orderings.
        n_samples = min(self.n_permutations, len(permutations))
        permutations_idx = np.random.choice(
            len(permutations), size=n_samples, replace=False
        )
        generated_outputs = [
            self._generate_output(permutations[i]) for i in permutations_idx
        ]

        # All summaries belong to one pseudo-task keyed by the empty string.
        data = pd.DataFrame(
            {"task": [""] * len(generated_outputs), "text": generated_outputs}
        )

        if self.permutation_aggregator is not None:
            return cast(str, self.permutation_aggregator.fit_predict(data)[""])

        # `Series.mode()` returns a Series (several values on ties); take the
        # first entry so that a plain string is returned.
        return cast(str, data.text.mode().iloc[0])

    def _generate_output(
        self, permutation: Union[Iterable[Any], "pd.Series[Any]"]
    ) -> str:
        """Summarize one ordering of texts with the model.

        Args:
            permutation: The texts to join with `concat_token`, in order.

        Returns:
            str: The first generated sequence, decoded without special tokens.
        """
        input_text = self.concat_token.join(permutation)
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(
            self.device
        )
        outputs = self.model.generate(input_ids, num_beams=self.num_beams)
        return cast(str, self.tokenizer.decode(outputs[0], skip_special_tokens=True))

concat_token: str = attr.ib(default=' | ') class-attribute instance-attribute

Token used for the workers' texts concatenation.

device: str = attr.ib(default='cpu') class-attribute instance-attribute

Device to use such as cpu or cuda.

model: PreTrainedModel = attr.ib() class-attribute instance-attribute

Pre-trained model for text summarization.

n_permutations: Optional[int] = attr.ib(default=None) class-attribute instance-attribute

Number of input permutations to use. If None, use a single permutation according to the input's order.

num_beams: int = attr.ib(default=16) class-attribute instance-attribute

Number of beams for beam search. 1 means no beam search.

permutation_aggregator: Optional[BaseTextsAggregator] = attr.ib(default=None) class-attribute instance-attribute

Text aggregation method to use for aggregating outputs of multiple input permutations if n_permutations is set.

tokenizer: PreTrainedTokenizer = attr.ib() class-attribute instance-attribute

fit_predict(data)

Run the aggregation and return the aggregated texts.

Args: data (DataFrame): Workers' text outputs. A pandas.DataFrame containing task, worker and text columns.

Returns: Series: Tasks' texts. A pandas.Series indexed by task such that result.loc[task] is the task's text.

Source code in crowdkit/learning/text_summarization.py
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def fit_predict(self, data: pd.DataFrame) -> "pd.Series[Any]":
    """Aggregate the workers' texts task by task.

    The model is moved to `self.device` first, and the result is also stored
    in `self.texts_`.

    Args:
        data (DataFrame): Workers' text outputs.
            A pandas.DataFrame containing `task`, `worker` and `text` columns.

    Returns:
        Series: Tasks' texts.
            A pandas.Series indexed by `task` such that `result.loc[task]`
            is the task's text.
    """
    required_columns = ["task", "worker", "text"]
    # Selecting the columns also fails early if one of them is missing.
    answers = data[required_columns]

    self.model = self.model.to(self.device)

    texts_by_task = answers.groupby("task")["text"]
    self.texts_ = texts_by_task.apply(self._aggregate_one)
    return self.texts_