snorkelflow_extensions.taxonomy_distillation.models.huggingface.HuggingFaceStudentHTC

class snorkelflow_extensions.taxonomy_distillation.models.huggingface.HuggingFaceStudentHTC(taxonomy: Taxonomy, config: Dict[str, Any], text_encoder: HuggingfaceTextEncoder, encoding_batch_size: int = 128)

Bases: StudentHTC

Hugging Face-based student model for hierarchical text classification.

This implementation combines Hugging Face transformer models for text encoding with multi-layer perceptron (MLP) classifiers for hierarchical classification decisions. The model encodes input text using pre-trained transformers, then applies local classifiers at each taxonomy level to make hierarchical predictions.

The architecture supports configurable MLP layers, PyTorch optimizers selected by name, early stopping, and class weight balancing for imbalanced training datasets. It is designed for efficient inference after training on teacher-generated labels.
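The page does not say how class weights are computed when balancing is enabled. As a minimal sketch of one common scheme, assuming inverse-frequency weights of the form n_samples / (n_classes * class_count), the idea looks like this. The function name `balanced_class_weights` is hypothetical and is not part of the library.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence


def balanced_class_weights(labels: Sequence[Hashable]) -> Dict[Hashable, float]:
    """Weight each class by n_samples / (n_classes * class_count).

    Assumption: this is an illustrative inverse-frequency scheme, not
    necessarily the one HuggingFaceStudentHTC uses internally.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


# Rare classes receive larger weights than frequent ones.
weights = balanced_class_weights(["a", "a", "a", "b"])
```

Weights like these are typically passed to the loss function so that errors on rare classes cost more than errors on frequent ones.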

__init__

__init__(taxonomy: Taxonomy, config: Dict[str, Any], text_encoder: HuggingfaceTextEncoder, encoding_batch_size: int = 128) → None

Initialize the HuggingFaceStudentHTC model. The model expects a config dictionary with the following keys:

    {
        "model": {
            "hidden_dim": <hidden dimension of MLP>,
            "num_layers": <number of layers in MLP>,
            "dropout_rate": <dropout rate of MLP>
        },
        "optimizer": {
            "name": <torch optimizer name>,
            "params": {<dictionary of kwargs for torch optimizer>}
        },
        "train": {
            "batch_size": <batch size>,
            "num_epochs": <maximum number of epochs>,
            "early_stopping_patience": <number of epochs to wait before early stopping>,
            "train_validation_split": <fraction of data to use for training>,
            "balance_class_weights": <whether to balance class weights>
        },
        "valid": {
            "batch_size": <batch size>
        }
    }

Parameters:

    taxonomy: The taxonomy data class instance to use.
    config: The configuration dictionary.
    text_encoder: The text encoder to use.
    encoding_batch_size: The batch size to use for encoding texts. Default is 128.
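As a concrete illustration of the config layout described above, the following sketch builds one such dictionary and checks it for missing keys. All values are placeholders chosen for illustration, not recommended settings. "AdamW" and "lr" are a real torch.optim class name and keyword argument, but how the library resolves the optimizer name is an assumption. The `missing_config_keys` helper is hypothetical and is not part of the library.

```python
from typing import Any, Dict, List

# Illustrative values only; tune them for your own data.
config: Dict[str, Any] = {
    "model": {"hidden_dim": 256, "num_layers": 2, "dropout_rate": 0.1},
    # Assumption: "name" refers to a class in torch.optim, "params" to its kwargs.
    "optimizer": {"name": "AdamW", "params": {"lr": 1e-3}},
    "train": {
        "batch_size": 64,
        "num_epochs": 20,
        "early_stopping_patience": 3,
        "train_validation_split": 0.8,  # fraction of data used for training
        "balance_class_weights": True,
    },
    "valid": {"batch_size": 128},
}

# Keys listed in the documented config layout.
EXPECTED_KEYS = {
    "model": ["hidden_dim", "num_layers", "dropout_rate"],
    "optimizer": ["name", "params"],
    "train": [
        "batch_size",
        "num_epochs",
        "early_stopping_patience",
        "train_validation_split",
        "balance_class_weights",
    ],
    "valid": ["batch_size"],
}


def missing_config_keys(cfg: Dict[str, Any]) -> List[str]:
    """Return documented keys absent from cfg, as sorted 'section.key' strings."""
    missing = []
    for section, keys in EXPECTED_KEYS.items():
        present = cfg.get(section, {})
        missing.extend(f"{section}.{key}" for key in keys if key not in present)
    return sorted(missing)


problems = missing_config_keys(config)
```

A dictionary that passes this check can then be passed as the config argument, together with a Taxonomy instance and a HuggingfaceTextEncoder.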

Methods

__init__(taxonomy, config, text_encoder[, ...])     Initialize the HuggingFaceStudentHTC model.
classify_text(text)                                 Classify the text.
classify_texts(texts)                               Classify a list of texts.
fit(teacher_labeled_data)                           Train the student model local classifiers.
get_property_labeled_data(labeled_data, prop)       Get the property classifications for a given property.
get_subtaxonomy_labeled_data(labeled_data, ...)     Get the labeled data for a given subtaxonomy.
init_classifiers()                                  Initialize the classifiers.
load(load_path)                                     Load the student model.
save(save_path)                                     Save the student model.
split_features_labels(features, labels, split)      Split the dataset into training and validation datasets.

init_classifiers traverses the taxonomy breadth-first to initialize the classifiers and returns None. classifiers has the following format:

    {
        <class_name>: {
            "properties": {
                <property_name>: {
                    "categories": {
                        <class_name>: {
                            "properties": {...}
                        }
                    },
                    "classifier_index_to_class_name": <list of class names>,
                    "classifier": <classifier>
                }
            }
        }
    }
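Both init_classifiers and fit are described as breadth-first traversals of the taxonomy. The following is a minimal sketch of that visiting order over a nested dictionary shaped like the documented classifiers structure. The function `bfs_properties` and the toy taxonomy are hypothetical, and the library's actual traversal code is not shown on this page.

```python
from collections import deque
from typing import Any, Dict, List, Tuple


def bfs_properties(classifiers: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Return (class_name, property_name) pairs in breadth-first order."""
    order: List[Tuple[str, str]] = []
    queue = deque(classifiers.items())
    while queue:
        class_name, node = queue.popleft()
        for prop_name, prop in node.get("properties", {}).items():
            order.append((class_name, prop_name))
            # Child classes are visited only after every class already queued.
            queue.extend(prop.get("categories", {}).items())
    return order


# Toy structure: a root class with two properties, one of which has subclasses.
toy = {
    "root": {
        "properties": {
            "topic": {
                "categories": {
                    "sports": {"properties": {"league": {"categories": {}}}},
                    "science": {"properties": {"field": {"categories": {}}}},
                }
            },
            "tone": {"categories": {}},
        }
    }
}
order = bfs_properties(toy)
```

In this sketch every property of the root class is visited before any property of a nested class, which is the level-by-level order that a breadth-first search produces.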

classify_text

classify_text(text: str) → Dict[str, Any]

Classify the text.

Parameters:

    text: The text to classify.

Returns: The classification result.

classify_texts

classify_texts(texts: List[str]) → List[Dict[str, Any]]

Classify a list of texts.

Parameters:

    texts: The texts to classify.

Returns: The classification results.

fit

fit(teacher_labeled_data: Dict[str, Any]) → Dict[str, Any]

Train the student model local classifiers. The method traverses the taxonomy breadth-first to train the classifiers and returns a dictionary of the training results.

Parameters:

    teacher_labeled_data: The labeled data from the teacher model.

Returns: The training results. The results dictionary has the following format:

    {
        taxonomy.name: {
            "properties": {
                <property_name>: {
                    "training_dataset_size": <training_dataset_size>,
                    "validation_dataset_size": <validation_dataset_size>,
                    "training_accuracy": <training_accuracy>,
                    "training_macro_f1": <training_macro_f1>,
                    "validation_accuracy": <validation_accuracy>,
                    "validation_macro_f1": <validation_macro_f1>,
                    "categories": {
                        <class_name>: {
                            "properties": {
                                <subproperty_name>: {
                                    "training_dataset_size": <training_dataset_size>,
                                    "validation_dataset_size": <validation_dataset_size>,
                                    "training_accuracy": <training_accuracy>,
                                    "training_macro_f1": <training_macro_f1>,
                                    "validation_accuracy": <validation_accuracy>,
                                    "validation_macro_f1": <validation_macro_f1>
                                }
                            },
                            "class_count": <class_count>
                        }
                    }
                }
            },
            "class_count": <class_count>
        },
        "setup_runtime": <setup_runtime>,
        "training_runtime": <training_runtime>
    }
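Because the results dictionary is nested to the depth of the taxonomy, reading one metric for every classifier takes a recursive walk. The sketch below flattens a dictionary in the documented format into path-keyed metrics. The helper `iter_property_metrics` is hypothetical, and the numbers in the toy results are made-up placeholders, not real training output.

```python
from typing import Any, Dict, Iterator, Tuple


def iter_property_metrics(
    node: Dict[str, Any], path: Tuple[str, ...] = ()
) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Yield (path, metrics) for each property classifier under node."""
    for prop_name, prop in node.get("properties", {}).items():
        prop_path = path + (prop_name,)
        # Everything except the nested "categories" entry is a metric.
        metrics = {k: v for k, v in prop.items() if k != "categories"}
        yield "/".join(prop_path), metrics
        for class_name, child in prop.get("categories", {}).items():
            yield from iter_property_metrics(child, prop_path + (class_name,))


# Placeholder results in the documented format (most metric keys omitted).
results = {
    "root": {
        "properties": {
            "topic": {
                "validation_accuracy": 0.85,
                "validation_macro_f1": 0.8,
                "categories": {
                    "sports": {
                        "properties": {"league": {"validation_macro_f1": 0.7}},
                        "class_count": 40,
                    }
                },
            }
        },
        "class_count": 100,
    },
    "setup_runtime": 1.5,
    "training_runtime": 12.0,
}

f1_by_property = {}
for name, node in results.items():
    # Skip the scalar runtime entries that sit beside the taxonomy entry.
    if isinstance(node, dict):
        for path, metrics in iter_property_metrics(node, (name,)):
            f1_by_property[path] = metrics["validation_macro_f1"]
```

A flat mapping like this makes it easy to sort classifiers by validation score and spot weak levels of the hierarchy.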

load

load(load_path: str) → None

Load the student model.

Parameters:

    load_path: The path to load the model.

Returns: None

save

save(save_path: str) → None

Save the student model. This method is not implemented.

Parameters:

    save_path: The path to save the model.

Returns: None

split_features_labels

split_features_labels(features: Tensor, labels: Tensor, split: float) → Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]

Split the dataset into training and validation datasets.

Parameters:

    features: The features tensor.
    labels: The labels tensor.
    split: The fraction of the data to use for training.

Returns: A tuple of the training and validation datasets.
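The return shape above, ((train_features, train_labels), (valid_features, valid_labels)), can be illustrated with a fraction-based split. This is a minimal sketch under stated assumptions: it performs a plain ordered split on anything that supports len() and slicing, and it uses Python lists for the demonstration rather than tensors. Whether the library shuffles or stratifies before splitting is not stated on this page, and `split_by_fraction` is a hypothetical name.

```python
from typing import Any, Tuple


def split_by_fraction(
    features: Any, labels: Any, split: float
) -> Tuple[Tuple[Any, Any], Tuple[Any, Any]]:
    """Split aligned features and labels; the first `split` fraction is training."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    if not 0.0 <= split <= 1.0:
        raise ValueError("split must be a fraction between 0 and 1")
    n_train = int(len(features) * split)
    train = (features[:n_train], labels[:n_train])
    valid = (features[n_train:], labels[n_train:])
    return train, valid


features = [[float(i)] for i in range(10)]
labels = [i % 2 for i in range(10)]
(train_x, train_y), (valid_x, valid_y) = split_by_fraction(features, labels, 0.8)
```

With ten examples and a split of 0.8, the sketch puts the first eight examples in the training set and the last two in the validation set.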