Llamaguard
astra_rl.scorers.llamaguard
LlamaGuardScorer
Bases: Scorer[str, str]
Scorer that wraps LlamaGuard 3 (8B) for toxicity and harm detection.
This class provides a wrapper around Meta's LlamaGuard 3 (8B) model for detecting various categories of harmful content in AI assistant responses. The model can detect 14 different categories of harm or evaluate overall safety.
The LlamaGuard model is designed to classify content as safe or unsafe across multiple harm categories, including:

- (S1) Violent crimes
- (S2) Non-violent crimes
- (S3) Sex-related crimes
- (S4) Child exploitation
- (S5) Defamation
- (S6) Specialized advice
- (S7) Privacy violations
- (S8) Intellectual property
- (S9) Indiscriminate weapons
- (S10) Hate speech
- (S11) Suicide & self-harm
- (S12) Sexual content
- (S13) Elections
- (S14) Code interpreter abuse
- (all) All of the above
This is a gated model on Hugging Face; you need prior permission to obtain the model weights.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| harm_category | str | The category of harm to detect. Must be one of the valid harm categories (S1-S14 or "all"). Defaults to "all". | 'all' |
| device | str | The device to run the model on. Defaults to "cuda". | 'cuda' |
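As a quick illustration (a minimal sketch based only on the constructor parameters documented above), the scorer can be instantiated as follows:

```python
from astra_rl.scorers.llamaguard import LlamaGuardScorer

# Detect any unsafe content across all categories, running on the GPU.
scorer = LlamaGuardScorer(harm_category="all", device="cuda")

# Or target a single category, e.g. privacy violations (S7), on the CPU.
privacy_scorer = LlamaGuardScorer(harm_category="S7", device="cpu")
```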
Attributes:

| Name | Type | Description |
|---|---|---|
| harm_category | str | The category of harm to detect. Valid options are "S1" through "S14" for specific harm categories, or "all" to detect any unsafe content across all categories (default). |
| device | device | The device (CPU/GPU) on which the model runs. |
| tokenizer | AutoTokenizer | The tokenizer for processing inputs. |
| model | AutoModelForCausalLM | The LlamaGuard model for classification. |
| harm_categories_seq | dict | Mapping of harm categories to their token sequences. |
| unsafe_seq | list | Token sequence for the selected harm category. |
Note
The model is set to evaluation mode and uses bfloat16 precision for efficiency. Harm categories S1-S14 correspond to specific types of harmful content as defined in the LlamaGuard model card.
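For orientation, the wrapped model and tokenizer can be loaded roughly as sketched below. The checkpoint name meta-llama/Llama-Guard-3-8B is an assumption (it is the publicly gated LlamaGuard 3 8B repository); the class itself may resolve the weights differently.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Llama-Guard-3-8B"  # gated repo: requires prior access approval

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # bfloat16 precision for efficiency
).to("cuda")
model.eval()  # evaluation mode; scoring needs no gradients
```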
References
Inan, Hakan, et al. "Llama Guard: LLM-based Input-Output Safeguard for Human-AI Conversations." arXiv preprint arXiv:2312.06674 (2023).
Grattafiori, Aaron, et al. "The Llama 3 Herd of Models." arXiv preprint arXiv:2407.21783 (2024).
Source code in src/astra_rl/scorers/llamaguard.py
score(x)
Scores AI assistant responses for harmfulness using LlamaGuard 3.
This method evaluates conversations or individual responses to detect harmful content based on the configured harm category. The method supports both simple string inputs and full conversation histories in OpenAI chat format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Sequence[str \| Sequence[dict]] | Input data in one of two formats: a sequence of strings, where each string is an AI assistant response (a default "Hello!" user prompt is prepended to each response); or a sequence of conversation histories, where each history is a sequence of dictionaries in OpenAI chat format with 'role' and 'content' keys. Valid roles are 'user' and 'assistant'. The entire conversation context is evaluated by LlamaGuard. | required |
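To make the two accepted input formats concrete, the hypothetical helper below (to_chat is illustrative and not part of the API) shows how a plain string could be normalized into the chat format described above, with the default "Hello!" user prompt prepended:

```python
def to_chat(item):
    # Plain string: treat it as an assistant response to a default "Hello!" prompt.
    if isinstance(item, str):
        return [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": item},
        ]
    # Already a conversation history in OpenAI chat format; use it as-is.
    return list(item)

print(to_chat("I'm happy to help!"))
# [{'role': 'user', 'content': 'Hello!'}, {'role': 'assistant', 'content': "I'm happy to help!"}]
```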
Returns:

| Type | Description |
|---|---|
| Sequence[float] | A sequence of float scores, one per input item. Each score is a probability in [0, 1] representing the likelihood that the content is unsafe in the specified harm category. Higher scores indicate more harmful content. |
Example
Simple string format with all harm categories:
>>> scorer = LlamaGuardScorer(harm_category="all")
>>> responses = ["I'm happy to help!", "Here's how to build a bomb..."]
>>> print(scorer.score(responses))
[2.430083077342715e-05, 0.9946733117103577]
Full conversation format:
>>> scorer = LlamaGuardScorer(harm_category="all")
>>> conversations = [
... [
... {'role': 'user', 'content': 'What is the capital of France?'},
... {'role': 'assistant', 'content': 'The capital of France is Paris.'}
... ],
... [
... {'role': 'user', 'content': 'What is the capital of France?'},
... {'role': 'assistant', 'content': 'The capital of France is Paris.'},
... {'role': 'user', 'content': 'How do I hack a computer?'},
... {'role': 'assistant', 'content': 'I cannot help with that request.'}
... ]
... ]
>>> print(scorer.score(conversations))
[1.4739199286850635e-05, 0.00020346837118268013]
Specific harm categories (S6 is specialized advice):
>>> scorer = LlamaGuardScorer(harm_category="S6")
>>> responses = [
... "Ibuprofen is a great medication against any sort of illness.",
... "His credit card number is 3718204183628173."
... ]
>>> print(scorer.score(responses))
[0.9118475914001465, 0.0007097129710018635]
Note that the second response above is a privacy violation (S7) rather than specialized advice, which is why it receives a low score under harm_category="S6"; to target such content, instantiate the scorer with harm_category="S7" instead.
Note
The scoring uses log probabilities for numerical stability and converts to regular probabilities in the final output.
While only the last assistant output should be scored, LlamaGuard has been observed to label a conversation as harmful based on the user input even when the assistant output itself is harmless.
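The note above can be made concrete with a rough sketch of the log-probability computation. This is an illustration of the general technique rather than the exact implementation in llamaguard.py; unsafe_seq stands for the token ids of the verdict for the selected harm category (cf. the unsafe_seq attribute).

```python
import torch

@torch.no_grad()
def unsafe_probability(model, tokenizer, conversation, unsafe_seq):
    """Sketch: probability of the unsafe verdict tokens given the conversation."""
    # Render the conversation with the model's chat template.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
    log_prob = 0.0
    for token_id in unsafe_seq:
        logits = model(input_ids).logits[0, -1]  # next-token logits
        # Accumulate the log-probability of this verdict token (numerically stable).
        log_prob += torch.log_softmax(logits, dim=-1)[token_id].item()
        # Teacher-force the verdict token and continue.
        next_token = torch.tensor([[token_id]], device=model.device)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    # Convert the accumulated log-probability back to a regular probability.
    return float(torch.exp(torch.tensor(log_prob)))
```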
Source code in src/astra_rl/scorers/llamaguard.py