add /detect route for language detection
This commit is contained in:
parent
4ee081df66
commit
4309b55aa4
|
@ -1,4 +1,5 @@
|
|||
use anyhow::Error;
|
||||
use std::collections::HashMap;
|
||||
use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
|
||||
use silicon::utils::{Background, ShadowAdder};
|
||||
use std::path::PathBuf;
|
||||
|
@ -11,6 +12,15 @@ use crate::rgba::{ImageRgba, Rgba};
|
|||
type FontList = Vec<(String, f32)>;
|
||||
type Lines = Vec<u32>;
|
||||
|
||||
/// Evaluates a `Result` expression and yields the value inside `Ok`.
///
/// When the expression is an `Err`, the error value is discarded and the
/// enclosing function returns the second argument instead.
macro_rules! unwrap_or_return {
    ( $result:expr, $fallback:expr ) => {
        if let Ok(value) = $result {
            value
        } else {
            return $fallback;
        }
    };
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
/// Background image URL
|
||||
|
@ -132,15 +142,25 @@ impl Config {
|
|||
ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
|
||||
// Try using tensorflow to detect the language
|
||||
let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap();
|
||||
self.predict_language_with_tensorflow(ps, input_data)
|
||||
.unwrap_or_else(|_| ps.find_syntax_by_token("log").unwrap())
|
||||
let predictions = self.predict_language_with_tensorflow(ps, input_data).unwrap();
|
||||
|
||||
let mut max_score = -std::f32::INFINITY;
|
||||
let mut max_language = "log";
|
||||
for (language, score) in &predictions { // Borrow predictions here
|
||||
if *score > max_score {
|
||||
max_score = *score;
|
||||
max_language = language;
|
||||
}
|
||||
}
|
||||
|
||||
ps.find_syntax_by_token(max_language).unwrap_or_else(|| ps.find_syntax_by_token("log").unwrap())
|
||||
})
|
||||
},
|
||||
};
|
||||
Ok(language)
|
||||
}
|
||||
|
||||
pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<&'a SyntaxReference, Error> {
|
||||
pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<HashMap<String, f32>, Error> {
|
||||
if self.tf_model_graph.is_none() || self.tf_model.is_none() {
|
||||
return Err(Error::msg("TensorFlow model not loaded"));
|
||||
}
|
||||
|
@ -165,14 +185,14 @@ impl Config {
|
|||
|
||||
let classes: Tensor<String> = args.fetch(output_token_classes)?;
|
||||
|
||||
// Find the index of the highest score
|
||||
let max_index = scores.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap().0;
|
||||
let mut result: HashMap<String, f32> = HashMap::new();
|
||||
for (i, score) in scores.iter().enumerate() {
|
||||
let class = classes[i].clone();
|
||||
let log_score = score.log2();
|
||||
result.insert(class, log_score);
|
||||
}
|
||||
|
||||
|
||||
let language = &classes[max_index];
|
||||
let language = ps.find_syntax_by_token(language).unwrap();
|
||||
|
||||
Ok(language)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
||||
|
|
62
src/main.rs
62
src/main.rs
|
@ -7,6 +7,7 @@ use anyhow::Error;
|
|||
use lazy_static::lazy_static;
|
||||
use silicon as si;
|
||||
use silicon::utils::ToRgba;
|
||||
use tensorflow::Tensor;
|
||||
use std::collections::HashSet;
|
||||
use std::io::Cursor;
|
||||
use std::num::ParseIntError;
|
||||
|
@ -20,6 +21,7 @@ lazy_static! {
|
|||
static ref HIGHLIGHTING_ASSETS: si::assets::HighlightingAssets =
|
||||
silicon::assets::HighlightingAssets::new();
|
||||
}
|
||||
|
||||
macro_rules! unwrap_or_return {
|
||||
( $e:expr, $r:expr ) => {
|
||||
match $e {
|
||||
|
@ -87,6 +89,12 @@ async fn help() -> impl Responder {
|
|||
"GET /themes": "Return a list of available syntax themes.",
|
||||
"GET /languages": "Retuns a list of languages which can be parsed.",
|
||||
"GET /fonts": "Returns a list of available fonts.",
|
||||
"GET /detect": {
|
||||
"description": "Detect the language of the given code.",
|
||||
"parameters": {
|
||||
"code": "The code to detect the language of. Required."
|
||||
}
|
||||
},
|
||||
"GET /generate": {
|
||||
"description": "Generate an image from the given code.",
|
||||
"parameters": {
|
||||
|
@ -150,6 +158,59 @@ async fn fonts() -> impl Responder {
|
|||
HttpResponse::Ok().json(fonts)
|
||||
}
|
||||
|
||||
#[get("/detect")]
|
||||
async fn detect(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
||||
let args = CliArgs::parse();
|
||||
let ha = &*HIGHLIGHTING_ASSETS;
|
||||
|
||||
let (ps, _ts) = (&ha.syntax_set, &ha.theme_set);
|
||||
|
||||
let mut conf = config::Config::default();
|
||||
conf.code = info.code.clone();
|
||||
if conf.code.is_empty() {
|
||||
return HttpResponse::BadRequest()
|
||||
.append_header(("Content-Type", "application/json"))
|
||||
.body(r#"{"error": "code parameter is required"}"#);
|
||||
}
|
||||
|
||||
if args.tensorflow_model_dir.is_some() {
|
||||
conf.load_tensorflow_model(args.tensorflow_model_dir.unwrap().as_str());
|
||||
}
|
||||
|
||||
let input_data = Tensor::new(&[1]).with_values(&[conf.code.clone()]).unwrap();
|
||||
let predictions = unwrap_or_return!(
|
||||
conf.predict_language_with_tensorflow(ps, input_data),
|
||||
HttpResponse::BadRequest()
|
||||
.append_header(("Content-Type", "application/json"))
|
||||
.body(r#"{"error": "Failed to detect language."}"#)
|
||||
);
|
||||
|
||||
let mut sorted_predictions: Vec<_> = predictions.iter().collect();
|
||||
sorted_predictions.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
||||
|
||||
let min_score = predictions.iter().map(|(_, score)| *score).fold(f32::INFINITY, f32::min);
|
||||
let max_score = predictions.iter().map(|(_, score)| *score).fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
// Normalize scores and pick top 5
|
||||
let mut normalized_predictions: Vec<_> = predictions.iter().map(|(lang, score)| {
|
||||
let normalized_score = (score - min_score) / (max_score - min_score) * 100.0;
|
||||
(lang, normalized_score)
|
||||
}).collect();
|
||||
|
||||
normalized_predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
|
||||
let response = normalized_predictions
|
||||
.iter()
|
||||
// .take(5)
|
||||
.map(|(language, score)| format!("{{\"language\": \"{}\", \"score\": {}}}", language, score))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
HttpResponse::Ok()
|
||||
.append_header(("Content-Type", "application/json"))
|
||||
.body(format!("[{}]", response))
|
||||
}
|
||||
|
||||
#[get("/generate")]
|
||||
async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
||||
let args = CliArgs::parse();
|
||||
|
@ -290,6 +351,7 @@ async fn main() -> std::io::Result<()> {
|
|||
.service(themes)
|
||||
.service(languages)
|
||||
.service(fonts)
|
||||
.service(detect)
|
||||
.service(generate)
|
||||
})
|
||||
.bind((host.clone(), port.parse::<u16>().unwrap()))?
|
||||
|
|
Loading…
Reference in New Issue