snorkelflow_extensions.taxonomy_distillation.models.huggingface.HuggingFaceStudentHTC
- class snorkelflow_extensions.taxonomy_distillation.models.huggingface.HuggingFaceStudentHTC(taxonomy: Taxonomy, config: Dict[str, Any], text_encoder: HuggingfaceTextEncoder, encoding_batch_size: int = 128)
Bases: StudentHTC
Hugging Face-based student model for hierarchical text classification.
This implementation combines Hugging Face transformer models for text encoding with multi-layer perceptron (MLP) classifiers for hierarchical classification decisions. The model encodes input text using pre-trained transformers, then applies local classifiers at each taxonomy level to make hierarchical predictions.
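The per-level layout of local classifiers can be illustrated with a small breadth-first construction. This is a sketch under assumptions: the taxonomy is modeled as a plain nested dict (`TOY_TAXONOMY`, hypothetical, not the real `Taxonomy` class), and `fake_head` is a hypothetical placeholder for building an MLP head. The output mirrors the nested `classifiers` format documented for `init_classifiers`, with one classifier per property.

```python
from collections import deque
from typing import Any, Callable, Deque, Dict, Tuple

# Hypothetical taxonomy encoding (assumption, not the real Taxonomy class):
# a class node holds "properties"; each property holds "categories", and
# each category is itself a class node.
TOY_TAXONOMY: Dict[str, Any] = {
    "name": "product",
    "properties": {
        "topic": {
            "categories": {
                "electronics": {
                    "properties": {
                        "device": {
                            "categories": {
                                "phone": {"properties": {}},
                                "laptop": {"properties": {}},
                            }
                        }
                    }
                },
                "clothing": {"properties": {}},
            }
        }
    },
}


def build_classifiers(taxonomy: Dict[str, Any], make_classifier: Callable[[int], Any]) -> Dict[str, Any]:
    """Breadth-first construction of one local classifier per property node."""
    root_out: Dict[str, Any] = {"properties": {}}
    classifiers = {taxonomy["name"]: root_out}
    # Each queue entry pairs a taxonomy class node with its output node.
    queue: Deque[Tuple[Dict[str, Any], Dict[str, Any]]] = deque([(taxonomy, root_out)])
    while queue:
        tax_node, out_node = queue.popleft()
        for prop_name, prop in tax_node["properties"].items():
            class_names = list(prop["categories"])
            out_prop: Dict[str, Any] = {
                "categories": {},
                "classifier_index_to_class_name": class_names,
                # One head per property; its output size equals the number of categories.
                "classifier": make_classifier(len(class_names)),
            }
            out_node["properties"][prop_name] = out_prop
            for class_name in class_names:
                child_out: Dict[str, Any] = {"properties": {}}
                out_prop["categories"][class_name] = child_out
                queue.append((prop["categories"][class_name], child_out))
    return classifiers


def fake_head(num_outputs: int) -> Dict[str, int]:
    """Hypothetical placeholder for an MLP head; only records its output size."""
    return {"num_outputs": num_outputs}


classifiers = build_classifiers(TOY_TAXONOMY, fake_head)
topic = classifiers["product"]["properties"]["topic"]
print(topic["classifier_index_to_class_name"])  # ['electronics', 'clothing']
device = topic["categories"]["electronics"]["properties"]["device"]
print(device["classifier"])  # {'num_outputs': 2}
```

Leaf categories such as `clothing` still get an empty `"properties"` node, so a top-down traversal can stop there without special-casing.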
The architecture supports configurable MLP layers, a PyTorch optimizer selected by name, early stopping, and class-weight balancing for imbalanced training data. The student model is trained on teacher-generated labels and is designed for efficient inference.
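Class-weight balancing and early stopping are often implemented along the lines below. This is a minimal sketch, assuming the scikit-learn-style "balanced" heuristic (weight = n_samples / (n_classes * class_count)) and a patience counter on validation loss; the exact weighting formula and the monitored metric used by HuggingFaceStudentHTC are not stated in this reference, and both helper names are hypothetical.

```python
from collections import Counter
from typing import Dict, List, Sequence


def balanced_class_weights(labels: Sequence[int]) -> Dict[int, float]:
    """Assumed heuristic: weight = n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


def early_stopping_epoch(val_losses: List[float], patience: int) -> int:
    """Return the 0-based epoch at which training stops (last epoch if never triggered)."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses) - 1


print(balanced_class_weights([0, 0, 0, 1]))  # {0: 0.6666666666666666, 1: 2.0}
print(early_stopping_epoch([1.0, 0.8, 0.85, 0.9, 0.7], patience=2))  # 3
```

The rare class receives the larger weight, so its errors count more in a weighted loss; with `patience=2`, two consecutive epochs without improvement end training before the later improvement at epoch 4 is seen.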
- __init__(taxonomy: Taxonomy, config: Dict[str, Any], text_encoder: HuggingfaceTextEncoder, encoding_batch_size: int = 128) → None
Initialize the HuggingFaceStudentHTC model. HuggingFaceStudentHTC models expect a config dictionary with the following keys:
{
    "model": {
        "hidden_dim": <hidden dimension of MLP>,
        "num_layers": <number of layers in MLP>,
        "dropout_rate": <dropout rate of MLP>
    },
    "optimizer": {
        "name": <torch optimizer name>,
        "params": {<dictionary of kwargs for torch optimizer>}
    },
    "train": {
        "batch_size": <batch size>,
        "num_epochs": <maximum number of epochs>,
        "early_stopping_patience": <number of epochs to wait before early stopping>,
        "train_validation_split": <fraction of data to use for training>,
        "balance_class_weights": <whether to balance class weights>
    },
    "valid": {
        "batch_size": <batch size>
    }
}
:param taxonomy: The taxonomy data class instance to use.
:param config: The configuration dictionary.
:param text_encoder: The text encoder to use.
:param encoding_batch_size: The batch size to use for encoding texts. Default is 128.
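A filled-in config makes the expected shape concrete. The values below are illustrative assumptions, not documented defaults; `"AdamW"` is assumed to be resolved to `torch.optim.AdamW`, whose constructor accepts `lr` and `weight_decay`. The `missing_config_keys` helper is hypothetical and only checks that every key listed above is present.

```python
from typing import Any, Dict, List

# Illustrative values only (assumptions, not documented defaults).
config: Dict[str, Any] = {
    "model": {"hidden_dim": 256, "num_layers": 2, "dropout_rate": 0.1},
    "optimizer": {"name": "AdamW", "params": {"lr": 1e-3, "weight_decay": 0.01}},
    "train": {
        "batch_size": 64,
        "num_epochs": 20,
        "early_stopping_patience": 3,
        "train_validation_split": 0.8,  # 80% train, 20% validation
        "balance_class_weights": True,
    },
    "valid": {"batch_size": 128},
}

# Keys listed in the documented config structure, grouped by section.
REQUIRED_KEYS: Dict[str, List[str]] = {
    "model": ["hidden_dim", "num_layers", "dropout_rate"],
    "optimizer": ["name", "params"],
    "train": [
        "batch_size",
        "num_epochs",
        "early_stopping_patience",
        "train_validation_split",
        "balance_class_weights",
    ],
    "valid": ["batch_size"],
}


def missing_config_keys(cfg: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: return dotted paths of documented keys absent from cfg."""
    missing: List[str] = []
    for section, keys in REQUIRED_KEYS.items():
        sub = cfg.get(section, {})
        missing.extend(f"{section}.{key}" for key in keys if key not in sub)
    return missing


print(missing_config_keys(config))  # []
```

A dictionary like this is what would be passed as the `config` argument.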
Methods
- __init__(taxonomy, config, text_encoder[, ...]): Initialize the HuggingFaceStudentHTC model.
- classify_text(text): Classify the text.
- classify_texts(texts): Classify a list of texts.
- fit(teacher_labeled_data): Train the student model's local classifiers.
- get_property_labeled_data(labeled_data, prop): Get the property classifications for a given property.
- get_subtaxonomy_labeled_data(labeled_data, ...): Get the labeled data for a given subtaxonomy.
- init_classifiers(): Initialize the classifiers by traversing the taxonomy breadth-first (BFS). The classifiers have the following format:
  {<class_name>: {
      "properties": {
          <property_name>: {
              "categories": {
                  <class_name>: {"properties": {...}}
              },
              "classifier_index_to_class_name": <list of class names>,
              "classifier": <classifier>
          }
      }
  }}
  Returns: None.
- load(load_path): Load the student model.
- save(save_path): Save the student model.
- split_features_labels(features, labels, split): Split the dataset into training and validation datasets.

- classify_text(text: str) → Dict[str, Any]
Classify the text.
:param text: The text to classify.
Returns: The classification result.
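A hierarchical classifier of this kind typically predicts top-down: pick a category at each property, then descend only into that category's subtree. The sketch below assumes greedy argmax decoding over the nested `classifiers` format, with hypothetical lambda scorers standing in for trained MLP heads applied to a text embedding; the result shape produced here (`"label"` and `"properties"` keys) is hypothetical and is not the documented return format of `classify_text`.

```python
from typing import Any, Dict, List


def classify_top_down(embedding: List[float], node: Dict[str, Any]) -> Dict[str, Any]:
    """Greedy top-down decoding: argmax at each property, recurse into the winner only."""
    result: Dict[str, Any] = {}
    for prop_name, prop in node["properties"].items():
        scores = prop["classifier"](embedding)
        best_index = max(range(len(scores)), key=scores.__getitem__)
        class_name = prop["classifier_index_to_class_name"][best_index]
        result[prop_name] = {
            "label": class_name,
            "properties": classify_top_down(embedding, prop["categories"][class_name]),
        }
    return result


# Hypothetical tree; each "classifier" maps an embedding to one score per category.
tree: Dict[str, Any] = {
    "properties": {
        "topic": {
            "classifier_index_to_class_name": ["electronics", "clothing"],
            "classifier": lambda e: [e[0], e[1]],
            "categories": {
                "electronics": {
                    "properties": {
                        "device": {
                            "classifier_index_to_class_name": ["phone", "laptop"],
                            "classifier": lambda e: [e[2], 1 - e[2]],
                            "categories": {"phone": {"properties": {}}, "laptop": {"properties": {}}},
                        }
                    }
                },
                "clothing": {"properties": {}},
            },
        }
    }
}

print(classify_top_down([0.9, 0.1, 0.3], tree))
# {'topic': {'label': 'electronics', 'properties': {'device': {'label': 'laptop', 'properties': {}}}}}
```

Because only the predicted branch is visited, the `device` classifier never runs for texts routed to `clothing`.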
- classify_texts(texts: List[str]) → List[Dict[str, Any]]
Classify a list of texts.
:param texts: The texts to classify.
Returns: The classification results.
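Classifying many texts usually starts by encoding them in fixed-size chunks, which is what the constructor's `encoding_batch_size` controls. This is a sketch assuming a generic `encode` callable that maps a batch of strings to embeddings; `batched`, `encode_in_batches`, and `fake_encode` are hypothetical names, not part of `HuggingfaceTextEncoder`.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def encode_in_batches(
    texts: List[str],
    encode: Callable[[Sequence[str]], List[List[float]]],
    batch_size: int = 128,
) -> List[List[float]]:
    """Encode texts chunk by chunk so at most batch_size texts are encoded at once."""
    embeddings: List[List[float]] = []
    for batch in batched(texts, batch_size):
        embeddings.extend(encode(batch))
    return embeddings


calls: List[int] = []


def fake_encode(batch: Sequence[str]) -> List[List[float]]:
    """Hypothetical encoder: one feature per text (its length); records batch sizes."""
    calls.append(len(batch))
    return [[float(len(t))] for t in batch]


print(encode_in_batches(["a", "bb", "ccc"], fake_encode, batch_size=2))  # [[1.0], [2.0], [3.0]]
print(calls)  # [2, 1]
```

Output order matches input order, so each embedding can be paired back with its text before the local classifiers run.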
- fit(teacher_labeled_data: Dict[str, Any]) → Dict[str, Any]
Train the student model's local classifiers by traversing the taxonomy breadth-first, and return a dictionary of training results.
:param teacher_labeled_data: The labeled data from the teacher model.
Returns: The training results, in the following format:
{
    taxonomy.name: {
        "properties": {
            <property_name>: {
                "training_dataset_size": <training_dataset_size>,
                "validation_dataset_size": <validation_dataset_size>,
                "training_accuracy": <training_accuracy>,
                "training_macro_f1": <training_macro_f1>,
                "validation_accuracy": <validation_accuracy>,
                "validation_macro_f1": <validation_macro_f1>,
                "categories": {
                    <class_name>: {
                        "properties": {
                            <subproperty_name>: {
                                "training_dataset_size": <training_dataset_size>,
                                "validation_dataset_size": <validation_dataset_size>,
                                "training_accuracy": <training_accuracy>,
                                "training_macro_f1": <training_macro_f1>,
                                "validation_accuracy": <validation_accuracy>,
                                "validation_macro_f1": <validation_macro_f1>
                            }
                        },
                        "class_count": <class_count>
                    }
                }
            }
        },
        "class_count": <class_count>
    },
    "setup_runtime": <setup_runtime>,
    "training_runtime": <training_runtime>
}
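Because the results nest property metrics inside categories, a small recursive walk is a convenient way to flatten them. This sketch assumes only the documented structure; `iter_validation_f1` is a hypothetical helper, the path notation (`property=class/subproperty`) is an arbitrary choice, and the numbers in `results` are made up purely to exercise the function (other metric keys are omitted for brevity).

```python
from typing import Any, Dict, Iterator, Tuple


def iter_validation_f1(node: Dict[str, Any], path: str = "") -> Iterator[Tuple[str, float]]:
    """Yield (property path, validation_macro_f1) for every property under a results node."""
    for prop_name, stats in node.get("properties", {}).items():
        prop_path = f"{path}/{prop_name}" if path else prop_name
        yield prop_path, stats["validation_macro_f1"]
        for class_name, child in stats.get("categories", {}).items():
            yield from iter_validation_f1(child, f"{prop_path}={class_name}")


# Made-up values; only the keys the walker reads are included.
results: Dict[str, Any] = {
    "product": {
        "properties": {
            "topic": {
                "validation_macro_f1": 0.88,
                "categories": {
                    "electronics": {
                        "properties": {"device": {"validation_macro_f1": 0.75}},
                        "class_count": 2,
                    }
                },
            }
        },
        "class_count": 2,
    },
    "setup_runtime": 1.2,
    "training_runtime": 30.5,
}

scores = dict(iter_validation_f1(results["product"]))
print(scores)  # {'topic': 0.88, 'topic=electronics/device': 0.75}
print(min(scores.items(), key=lambda kv: kv[1]))  # ('topic=electronics/device', 0.75)
```

Starting the walk at `results[taxonomy.name]` skips the top-level runtime entries, which are not per-property metrics.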
- load(load_path: str) → None
Load the student model.
:param load_path: The path to load the model from.
Returns: None
- save(save_path: str) → None
Save the student model. This method is not implemented.
:param save_path: The path to save the model to.
Returns: None
- split_features_labels(features: Tensor, labels: Tensor, split: float) → Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]
Split the dataset into training and validation datasets.
:param features: The features tensor.
:param labels: The labels tensor.
:param split: The fraction of the data to use for training.
Returns: A tuple of the training and validation datasets.
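The return shape, `((train_features, train_labels), (valid_features, valid_labels))`, can be sketched with plain Python lists in place of `torch.Tensor` so it runs without PyTorch. The shuffling, the `seed` parameter, and rounding the training size down with `int()` are assumptions; this reference does not say whether the real method shuffles. With tensors, the same idea would index `features` and `labels` with a permutation from `torch.randperm`. `split_features_labels_sketch` is a hypothetical name.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

FeatT = TypeVar("FeatT")
LabelT = TypeVar("LabelT")


def split_features_labels_sketch(
    features: Sequence[FeatT],
    labels: Sequence[LabelT],
    split: float,
    seed: int = 0,
) -> Tuple[Tuple[List[FeatT], List[LabelT]], Tuple[List[FeatT], List[LabelT]]]:
    """Shuffle indices once, then put the first `split` fraction in the training set."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    if not 0.0 <= split <= 1.0:
        raise ValueError("split must be in [0, 1]")
    indices = list(range(len(features)))
    random.Random(seed).shuffle(indices)  # same permutation for features and labels
    n_train = int(len(indices) * split)
    train_idx, valid_idx = indices[:n_train], indices[n_train:]
    train = ([features[i] for i in train_idx], [labels[i] for i in train_idx])
    valid = ([features[i] for i in valid_idx], [labels[i] for i in valid_idx])
    return train, valid


feats = list(range(10))
labs = [x * 10 for x in feats]
(train_x, train_y), (valid_x, valid_y) = split_features_labels_sketch(feats, labs, split=0.8)
print(len(train_x), len(valid_x))  # 8 2
```

Shuffling one shared index list keeps each feature aligned with its label across both splits.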